[tint] Shuffle template generation code
Add helpers for loading templates from files. Include the file path in
the errors reported by the template helpers.
Move the intrinsic generation code out to a sub-package.
Use the subcmd package to allow for more future commands.
Change-Id: I909b654a2930f749b2a67ae29c3d1e90296c0523
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/146382
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
diff --git a/tools/src/cmd/gen/common/cmds.go b/tools/src/cmd/gen/common/cmds.go
new file mode 100644
index 0000000..a6103ba
--- /dev/null
+++ b/tools/src/cmd/gen/common/cmds.go
@@ -0,0 +1,31 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import (
+ "dawn.googlesource.com/dawn/tools/src/subcmd"
+)
+
+// Command is the sub-command type accepted by the 'gen' tool.
+type Command = subcmd.Command[*Config]
+
+// commands holds every command passed to Register, in registration order.
+var commands []Command
+
+// Register adds cmd to the commands exposed by the 'gen' tool.
+func Register(cmd Command) { commands = append(commands, cmd) }
+
+// Commands returns the registered commands, in registration order.
+func Commands() []Command { return commands }
diff --git a/tools/src/cmd/gen/common/config.go b/tools/src/cmd/gen/common/config.go
new file mode 100644
index 0000000..20a4c70
--- /dev/null
+++ b/tools/src/cmd/gen/common/config.go
@@ -0,0 +1,33 @@
+// Copyright 2023 The Dawn Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import "flag"
+
+// Config holds the configuration shared by all of the 'gen' sub-commands.
+type Config struct {
+ // Flags holds the command line flags common to every sub-command.
+ Flags struct {
+ Verbose bool // Emit additional logging
+ CheckStale bool // Don't emit anything, just check that files are up to date
+ }
+}
+
+// RegisterFlags binds the common flags to the default flag set (flag.CommandLine).
+func (c *Config) RegisterFlags() {
+ flags := &c.Flags
+ flag.BoolVar(&flags.Verbose, "v", false, "print verbose output")
+ flag.BoolVar(&flags.CheckStale, "check-stale", false, "don't emit anything, just check that files are up to date")
+}
diff --git a/tools/src/cmd/gen/common/header.go b/tools/src/cmd/gen/common/header.go
new file mode 100644
index 0000000..850f07b
--- /dev/null
+++ b/tools/src/cmd/gen/common/header.go
@@ -0,0 +1,68 @@
+// Copyright 2023 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// copyrightRegex captures the year of a copyright line whose comment tokens were replaced with '•'.
+var copyrightRegex = regexp.MustCompile(`• Copyright (\d+) The`)
+
+const header = `• Copyright %v The Tint Authors.
+•
+• Licensed under the Apache License, Version 2.0 (the "License");
+• you may not use this file except in compliance with the License.
+• You may obtain a copy of the License at
+•
+• http://www.apache.org/licenses/LICENSE-2.0
+•
+• Unless required by applicable law or agreed to in writing, software
+• distributed under the License is distributed on an "AS IS" BASIS,
+• WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+• See the License for the specific language governing permissions and
+• limitations under the License.
+
+‣
+• File generated by tools/src/cmd/gen
+• using the template:
+• %v
+•
+• Do not modify this file directly
+‣
+`
+
+// Header returns the generated-file banner for templatePath, using comment
+// (which must be non-empty) as the line-comment token. The copyright year is
+// kept from existing (the file's current content) if found, else the current year.
+func Header(existing, templatePath, comment string) string {
+ copyrightYear := time.Now().Year()
+ // Look for the existing copyright year, after normalizing comment tokens to '•'
+ normalized := strings.ReplaceAll(existing, comment, "•")
+ if match := copyrightRegex.FindStringSubmatch(normalized); len(match) == 2 {
+ if year, err := strconv.Atoi(match[1]); err == nil {
+ copyrightYear = year
+ }
+ }
+ // Expand '•' and '‣' (a rule line); escape '%' so comment can't inject Sprintf verbs
+ escaped := strings.ReplaceAll(comment, "%", "%%")
+ out := strings.ReplaceAll(header, "•", escaped)
+ out = strings.ReplaceAll(out, "‣", strings.Repeat(escaped, 80/len(comment)))
+ return fmt.Sprintf(out, copyrightYear, templatePath)
+}
diff --git a/tools/src/cmd/gen/main.go b/tools/src/cmd/gen/main.go
index 3578796..926e210 100644
--- a/tools/src/cmd/gen/main.go
+++ b/tools/src/cmd/gen/main.go
@@ -12,466 +12,64 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// gen scans the the project directory for '<file>.tmpl' files, producing code
-// from those template files.
+// gen generates code for the Tint project.
package main
import (
- "flag"
+ "context"
"fmt"
- "io"
- "io/ioutil"
- "math/rand"
"os"
- "os/exec"
- "path/filepath"
- "reflect"
- "regexp"
- "runtime"
- "strconv"
"strings"
- "time"
- "dawn.googlesource.com/dawn/tools/src/container"
- "dawn.googlesource.com/dawn/tools/src/fileutils"
- "dawn.googlesource.com/dawn/tools/src/glob"
- "dawn.googlesource.com/dawn/tools/src/template"
- "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/gen"
- "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/parser"
- "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/resolver"
- "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/sem"
+ "dawn.googlesource.com/dawn/tools/src/cmd/gen/common"
+ "dawn.googlesource.com/dawn/tools/src/cmd/gen/templates"
+ "dawn.googlesource.com/dawn/tools/src/subcmd"
+
+ // Register sub-commands
+ _ "dawn.googlesource.com/dawn/tools/src/cmd/gen/templates"
)
func main() {
- if err := run(); err != nil {
- fmt.Println(err)
+ ctx := context.Background()
+
+ if len(os.Args) == 1 || strings.HasPrefix(os.Args[1], "-") {
+ os.Args = append([]string{os.Args[0], "all"}, os.Args[1:]...)
+ }
+
+ cfg := &common.Config{}
+ cfg.RegisterFlags()
+
+ if err := subcmd.Run(ctx, cfg, common.Commands()...); err != nil {
+ if err != subcmd.ErrInvalidCLA {
+ fmt.Fprintln(os.Stderr, err)
+ }
os.Exit(1)
}
}
-func showUsage() {
- fmt.Println(`
-gen generates the templated code for the Tint compiler
-
-gen accepts a list of file paths to the templates to generate. If no templates
-are explicitly specified, then gen scans the '<dawn>/src/tint' and
-'<dawn>/test/tint' directories for '<file>.tmpl' files.
-
-usage:
- gen [flags] [template files]
-
-optional flags:`)
- flag.PrintDefaults()
- fmt.Println(``)
- os.Exit(1)
+func init() {
+ common.Register(&cmdAll{}) // exposes 'all' through common.Commands()
 }
-func run() error {
- outputDir := ""
- verbose := false
- checkStale := false
- flag.StringVar(&outputDir, "o", "", "custom output directory (optional)")
- flag.BoolVar(&verbose, "verbose", false, "print verbose output")
- flag.BoolVar(&checkStale, "check-stale", false, "don't emit anything, just check that files are up to date")
- flag.Parse()
+// cmdAll implements the 'all' sub-command, which runs every generator.
+type cmdAll struct{}
- staleFiles := []string{}
- projectRoot := fileutils.DawnRoot()
+func (cmdAll) Name() string {
+ return "all"
+}
- // Find clang-format
- clangFormatPath := findClangFormat(projectRoot)
- if clangFormatPath == "" {
- return fmt.Errorf("cannot find clang-format in <dawn>/buildtools nor PATH")
+func (cmdAll) Desc() string {
+ return `all runs all the generators`
+}
+
+func (c *cmdAll) RegisterFlags(ctx context.Context, cfg *common.Config) ([]string, error) {
+ return nil, nil // 'all' has no flags of its own
+}
+
+func (c cmdAll) Run(ctx context.Context, cfg *common.Config) error {
+ // 'all' currently consists of just the 'templates' generator.
+ if err := (templates.Cmd{}).Run(ctx, cfg); err != nil {
+ return err
 }
-
- files := flag.Args()
- if len(files) == 0 {
- // Recursively find all the template files in the <dawn>/src/tint and
- // <dawn>/test/tint and directories
- var err error
- files, err = glob.Scan(projectRoot, glob.MustParseConfig(`{
- "paths": [{"include": [
- "src/tint/**.tmpl",
- "test/tint/**.tmpl"
- ]}]
- }`))
- if err != nil {
- return err
- }
- } else {
- // Make all template file paths project-relative
- for i, f := range files {
- abs, err := filepath.Abs(f)
- if err != nil {
- return fmt.Errorf("failed to get absolute file path for '%v': %w", f, err)
- }
- if !strings.HasPrefix(abs, projectRoot) {
- return fmt.Errorf("template '%v' is not under project root '%v'", abs, projectRoot)
- }
- rel, err := filepath.Rel(projectRoot, abs)
- if err != nil {
- return fmt.Errorf("failed to get project relative file path for '%v': %w", f, err)
- }
- files[i] = rel
- }
- }
-
- cache := &genCache{}
-
- // For each template file...
- for _, relTmplPath := range files { // relative to project root
- if verbose {
- fmt.Println("processing", relTmplPath)
- }
- // Make tmplPath absolute
- tmplPath := filepath.Join(projectRoot, relTmplPath)
- tmplDir := filepath.Dir(tmplPath)
-
- // Read the template file
- tmpl, err := ioutil.ReadFile(tmplPath)
- if err != nil {
- return fmt.Errorf("failed to open '%v': %w", tmplPath, err)
- }
-
- // Create or update the file at relPath if the file content has changed,
- // preserving the copyright year in the header.
- // relPath is a path relative to the template
- writeFile := func(relPath, body string) error {
- var outPath string
- if outputDir != "" {
- relTmplDir := filepath.Dir(relTmplPath)
- outPath = filepath.Join(outputDir, relTmplDir, relPath)
- } else {
- outPath = filepath.Join(tmplDir, relPath)
- }
-
- copyrightYear := time.Now().Year()
-
- // Load the old file
- existing, err := ioutil.ReadFile(outPath)
- if err == nil {
- // Look for the existing copyright year
- if match := copyrightRegex.FindStringSubmatch(string(existing)); len(match) == 2 {
- if year, err := strconv.Atoi(match[1]); err == nil {
- copyrightYear = year
- }
- }
- }
-
- // Write the common file header
- if verbose {
- fmt.Println(" writing", outPath)
- }
- sb := strings.Builder{}
- sb.WriteString(fmt.Sprintf(header, copyrightYear, filepath.ToSlash(relTmplPath)))
- sb.WriteString(body)
- oldContent, newContent := string(existing), sb.String()
-
- if oldContent != newContent {
- if checkStale {
- staleFiles = append(staleFiles, outPath)
- } else {
- if err := os.MkdirAll(filepath.Dir(outPath), 0777); err != nil {
- return fmt.Errorf("failed to create directory for '%v': %w", outPath, err)
- }
- if err := ioutil.WriteFile(outPath, []byte(newContent), 0666); err != nil {
- return fmt.Errorf("failed to write file '%v': %w", outPath, err)
- }
- }
- }
-
- return nil
- }
-
- // Write the content generated using the template and semantic info
- sb := strings.Builder{}
- if err := generate(string(tmpl), cache, &sb, writeFile); err != nil {
- return fmt.Errorf("while processing '%v': %w", tmplPath, err)
- }
-
- if body := sb.String(); body != "" {
- _, tmplFileName := filepath.Split(tmplPath)
- outFileName := strings.TrimSuffix(tmplFileName, ".tmpl")
-
- switch filepath.Ext(outFileName) {
- case ".cc", ".h", ".inl":
- body, err = clangFormat(body, clangFormatPath)
- if err != nil {
- return err
- }
- }
-
- if err := writeFile(outFileName, body); err != nil {
- return err
- }
- }
- }
-
- if len(staleFiles) > 0 {
- fmt.Println(len(staleFiles), "files need regenerating:")
- for _, path := range staleFiles {
- if rel, err := filepath.Rel(projectRoot, path); err == nil {
- fmt.Println(" •", rel)
- } else {
- fmt.Println(" •", path)
- }
- }
- fmt.Println("Regenerate these files with: ./tools/run gen")
- os.Exit(1)
- }
-
return nil
}
-
-type intrinsicCache struct {
- path string
- cachedSem *sem.Sem // lazily built by sem()
- cachedTable *gen.IntrinsicTable // lazily built by intrinsicTable()
- cachedPermuter *gen.Permuter // lazily built by permute()
-}
-
-// Sem lazily parses and resolves the intrinsic.def file, returning the semantic info.
-func (i *intrinsicCache) Sem() (*sem.Sem, error) {
- if i.cachedSem == nil {
- // Load the intrinsic definition file
- defPath := filepath.Join(fileutils.DawnRoot(), i.path)
-
- defSource, err := os.ReadFile(defPath)
- if err != nil {
- return nil, err
- }
-
- // Parse the definition file to produce an AST
- ast, err := parser.Parse(string(defSource), i.path)
- if err != nil {
- return nil, err
- }
-
- // Resolve the AST to produce the semantic info
- sem, err := resolver.Resolve(ast)
- if err != nil {
- return nil, err
- }
-
- i.cachedSem = sem
- }
- return i.cachedSem, nil
-}
-
-// Table lazily calls and returns the result of BuildIntrinsicTable(),
-// caching the result for repeated calls.
-func (i *intrinsicCache) Table() (*gen.IntrinsicTable, error) {
- if i.cachedTable == nil {
- sem, err := i.Sem()
- if err != nil {
- return nil, err
- }
- i.cachedTable, err = gen.BuildIntrinsicTable(sem)
- if err != nil {
- return nil, err
- }
- }
- return i.cachedTable, nil
-}
-
-// Permute lazily calls NewPermuter(), caching the result for repeated calls,
-// then passes the argument to Permutator.Permute()
-func (i *intrinsicCache) Permute(overload *sem.Overload) ([]gen.Permutation, error) {
- if i.cachedPermuter == nil {
- sem, err := i.Sem()
- if err != nil {
- return nil, err
- }
- i.cachedPermuter, err = gen.NewPermuter(sem)
- if err != nil {
- return nil, err
- }
- }
- return i.cachedPermuter.Permute(overload)
-}
-
-// Cache for objects that are expensive to build, and can be reused between templates.
-type genCache struct {
- intrinsicsCache container.Map[string, *intrinsicCache]
-}
-
-func (g *genCache) intrinsics(path string) *intrinsicCache {
- if g.intrinsicsCache == nil {
- g.intrinsicsCache = container.NewMap[string, *intrinsicCache]()
- }
- i := g.intrinsicsCache[path]
- if i == nil {
- i = &intrinsicCache{path: path}
- g.intrinsicsCache[path] = i
- }
- return i
-}
-
-var copyrightRegex = regexp.MustCompile(`// Copyright (\d+) The`)
-
-const header = `// Copyright %v The Tint Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-////////////////////////////////////////////////////////////////////////////////
-// File generated by tools/src/cmd/gen
-// using the template:
-// %v
-//
-// Do not modify this file directly
-////////////////////////////////////////////////////////////////////////////////
-
-`
-
-type generator struct {
- cache *genCache
- writeFile WriteFile
- rnd *rand.Rand
-}
-
-// WriteFile is a function that Generate() may call to emit a new file from a
-// template.
-// relPath is the relative path from the currently executing template.
-// content is the file content to write.
-type WriteFile func(relPath, content string) error
-
-// generate executes the template tmpl, writing the output to w.
-// See https://golang.org/pkg/text/template/ for documentation on the template
-// syntax.
-func generate(tmpl string, cache *genCache, w io.Writer, writeFile WriteFile) error {
- g := generator{
- cache: cache,
- writeFile: writeFile,
- rnd: rand.New(rand.NewSource(4561123)),
- }
-
- funcs := map[string]interface{}{
- "SplitDisplayName": gen.SplitDisplayName,
- "Scramble": g.scramble,
- "IsEnumEntry": is(sem.EnumEntry{}),
- "IsEnumMatcher": is(sem.EnumMatcher{}),
- "IsFQN": is(sem.FullyQualifiedName{}),
- "IsInt": is(1),
- "IsTemplateEnumParam": is(sem.TemplateEnumParam{}),
- "IsTemplateNumberParam": is(sem.TemplateNumberParam{}),
- "IsTemplateTypeParam": is(sem.TemplateTypeParam{}),
- "IsType": is(sem.Type{}),
- "ElementType": gen.ElementType,
- "DeepestElementType": gen.DeepestElementType,
- "IsAbstract": gen.IsAbstract,
- "IsDeclarable": gen.IsDeclarable,
- "IsHostShareable": gen.IsHostShareable,
- "OverloadUsesF16": gen.OverloadUsesF16,
- "OverloadUsesReadWriteStorageTexture": gen.OverloadUsesReadWriteStorageTexture,
- "IsFirstIn": isFirstIn,
- "IsLastIn": isLastIn,
- "LoadIntrinsics": func(path string) *intrinsicCache { return g.cache.intrinsics(path) },
- "WriteFile": func(relPath, content string) (string, error) { return "", g.writeFile(relPath, content) },
- }
- return template.Run(tmpl, w, funcs)
-}
-
-// scramble randomly modifies the input string so that it is no longer equal to
-// any of the strings in 'avoid'.
-func (g *generator) scramble(str string, avoid container.Set[string]) (string, error) {
- bytes := []byte(str)
- passes := g.rnd.Intn(5) + 1
-
- const chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz"
-
- char := func() byte { return chars[g.rnd.Intn(len(chars))] }
- replace := func(at int) { bytes[at] = char() }
- delete := func(at int) { bytes = append(bytes[:at], bytes[at+1:]...) }
- insert := func(at int) { bytes = append(append(bytes[:at], char()), bytes[at:]...) }
-
- for i := 0; i < passes || avoid.Contains(string(bytes)); i++ {
- if len(bytes) > 0 {
- at := g.rnd.Intn(len(bytes))
- switch g.rnd.Intn(3) {
- case 0:
- replace(at)
- case 1:
- delete(at)
- case 2:
- insert(at)
- }
- } else {
- insert(0)
- }
- }
- return string(bytes), nil
-}
-
-// is returns a function that returns true if the value passed to the function
-// matches the type of 'ty'.
-func is(ty interface{}) func(interface{}) bool {
- rty := reflect.TypeOf(ty)
- return func(v interface{}) bool {
- ty := reflect.TypeOf(v)
- return ty == rty || ty == reflect.PtrTo(rty)
- }
-}
-
-// isFirstIn returns true if v is the first element of the given slice.
-func isFirstIn(v, slice interface{}) bool {
- s := reflect.ValueOf(slice)
- count := s.Len()
- if count == 0 {
- return false
- }
- return s.Index(0).Interface() == v
-}
-
-// isFirstIn returns true if v is the last element of the given slice.
-func isLastIn(v, slice interface{}) bool {
- s := reflect.ValueOf(slice)
- count := s.Len()
- if count == 0 {
- return false
- }
- return s.Index(count-1).Interface() == v
-}
-
-// Invokes the clang-format executable at 'exe' to format the file content 'in'.
-// Returns the formatted file.
-func clangFormat(in, exe string) (string, error) {
- cmd := exec.Command(exe)
- cmd.Stdin = strings.NewReader(in)
- out, err := cmd.CombinedOutput()
- if err != nil {
- return "", fmt.Errorf("clang-format failed:\n%v\n%v", string(out), err)
- }
- return string(out), nil
-}
-
-// Looks for clang-format in the 'buildtools' directory, falling back to PATH
-func findClangFormat(projectRoot string) string {
- var path string
- switch runtime.GOOS {
- case "linux":
- path = filepath.Join(projectRoot, "buildtools/linux64/clang-format")
- case "darwin":
- path = filepath.Join(projectRoot, "buildtools/mac/clang-format")
- case "windows":
- path = filepath.Join(projectRoot, "buildtools/win/clang-format.exe")
- }
- if fileutils.IsExe(path) {
- return path
- }
- var err error
- path, err = exec.LookPath("clang-format")
- if err == nil {
- return path
- }
- return ""
-}
diff --git a/tools/src/cmd/gen/templates/templates.go b/tools/src/cmd/gen/templates/templates.go
new file mode 100644
index 0000000..6af50c3
--- /dev/null
+++ b/tools/src/cmd/gen/templates/templates.go
@@ -0,0 +1,421 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package templates
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io"
+ "math/rand"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "reflect"
+ "runtime"
+ "strings"
+
+ "dawn.googlesource.com/dawn/tools/src/cmd/gen/common"
+ "dawn.googlesource.com/dawn/tools/src/container"
+ "dawn.googlesource.com/dawn/tools/src/fileutils"
+ "dawn.googlesource.com/dawn/tools/src/glob"
+ "dawn.googlesource.com/dawn/tools/src/template"
+ "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/gen"
+ "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/parser"
+ "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/resolver"
+ "dawn.googlesource.com/dawn/tools/src/tint/intrinsic/sem"
+)
+
+func init() { common.Register(&Cmd{}) }
+
+// Cmd is the 'templates' sub-command, which expands <file>.tmpl templates (by
+// default, those found under the Tint source and test directories).
+type Cmd struct{}
+
+// Name returns the sub-command's name, "templates".
+func (Cmd) Name() string { return "templates" }
+
+// Desc returns a short description of the sub-command.
+func (Cmd) Desc() string {
+ return `templates generates files from <file>.tmpl files found in the Tint source and test directories`
+}
+
+// RegisterFlags is a no-op: the sub-command has no flags of its own.
+func (c *Cmd) RegisterFlags(ctx context.Context, cfg *common.Config) ([]string, error) {
+ return nil, nil
+}
+
+// Run expands the templates named on the command line (by default, every
+// '**.tmpl' file under <dawn>/src/tint and <dawn>/test/tint) into files beside
+// each template. With -check-stale nothing is written; stale files are an error.
+func (c Cmd) Run(ctx context.Context, cfg *common.Config) error {
+ staleFiles := []string{}
+ projectRoot := fileutils.DawnRoot()
+
+ // Find clang-format
+ clangFormatPath := findClangFormat(projectRoot)
+ if clangFormatPath == "" {
+ return fmt.Errorf("cannot find clang-format in <dawn>/buildtools nor PATH")
+ }
+
+ files := flag.Args()
+ if len(files) == 0 {
+ // Recursively find all the template files in the <dawn>/src/tint and
+ // <dawn>/test/tint and directories
+ var err error
+ files, err = glob.Scan(projectRoot, glob.MustParseConfig(`{
+ "paths": [{"include": [
+ "src/tint/**.tmpl",
+ "test/tint/**.tmpl"
+ ]}]
+ }`))
+ if err != nil {
+ return err
+ }
+ } else {
+ // Make all template file paths project-relative
+ for i, f := range files {
+ abs, err := filepath.Abs(f)
+ if err != nil {
+ return fmt.Errorf("failed to get absolute file path for '%v': %w", f, err)
+ }
+ rel, err := filepath.Rel(projectRoot, abs)
+ if err != nil {
+ return fmt.Errorf("failed to get project relative file path for '%v': %w", f, err)
+ }
+ // Test path components, not a string prefix: '<root>-x/a.tmpl' is not under '<root>'.
+ if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
+ return fmt.Errorf("template '%v' is not under project root '%v'", abs, projectRoot)
+ }
+ files[i] = rel
+ }
+ }
+
+ cache := &genCache{}
+
+ // For each template file...
+ for _, relTmplPath := range files { // relative to project root
+ if cfg.Flags.Verbose {
+ fmt.Println("processing", relTmplPath)
+ }
+ // Make tmplPath absolute
+ tmplPath := filepath.Join(projectRoot, relTmplPath)
+ tmplDir := filepath.Dir(tmplPath)
+
+ // Create or update the file at relPath if the file content has changed,
+ // preserving the copyright year in the header.
+ // relPath is a path relative to the template
+ writeFile := func(relPath, body string) error {
+ outPath := filepath.Join(tmplDir, relPath)
+
+ // Load the old file
+ existing, err := os.ReadFile(outPath)
+ if err != nil {
+ existing = nil
+ }
+
+ // Prepend the common file header (keeping the existing copyright year)
+ if cfg.Flags.Verbose {
+ fmt.Println(" writing", outPath)
+ }
+ newContent := common.Header(string(existing), filepath.ToSlash(relTmplPath), "//") + "\n" + body
+ if string(existing) != newContent {
+ if cfg.Flags.CheckStale {
+ staleFiles = append(staleFiles, outPath)
+ } else {
+ if err := os.MkdirAll(filepath.Dir(outPath), 0777); err != nil {
+ return fmt.Errorf("failed to create directory for '%v': %w", outPath, err)
+ }
+ if err := os.WriteFile(outPath, []byte(newContent), 0666); err != nil {
+ return fmt.Errorf("failed to write file '%v': %w", outPath, err)
+ }
+ }
+ }
+
+ return nil
+ }
+
+ // Write the content generated using the template and semantic info
+ sb := strings.Builder{}
+ if err := generate(tmplPath, cache, &sb, writeFile); err != nil {
+ return fmt.Errorf("while processing '%v': %w", tmplPath, err)
+ }
+
+ if body := sb.String(); body != "" {
+ _, tmplFileName := filepath.Split(tmplPath)
+ outFileName := strings.TrimSuffix(tmplFileName, ".tmpl")
+
+ switch filepath.Ext(outFileName) {
+ case ".cc", ".h", ".inl":
+ var err error
+ body, err = clangFormat(body, clangFormatPath)
+ if err != nil {
+ return err
+ }
+ }
+
+ if err := writeFile(outFileName, body); err != nil {
+ return err
+ }
+ }
+ }
+
+ if len(staleFiles) > 0 {
+ fmt.Println(len(staleFiles), "files need regenerating:")
+ for _, path := range staleFiles {
+ if rel, err := filepath.Rel(projectRoot, path); err == nil {
+ fmt.Println(" •", rel)
+ } else {
+ fmt.Println(" •", path)
+ }
+ }
+ fmt.Println("Regenerate these files with: ./tools/run gen")
+ // Return (not os.Exit) so callers such as 'all' and deferred cleanup keep control.
+ return fmt.Errorf("%v generated file(s) are stale", len(staleFiles))
+ }
+
+ return nil
+}
+
+// intrinsicCache lazily builds, and caches, data derived from one intrinsic definition file.
+type intrinsicCache struct {
+ path string // definition file path, relative to the Dawn root
+ cachedSem *sem.Sem // lazily built by Sem()
+ cachedTable *gen.IntrinsicTable // lazily built by Table()
+ cachedPermuter *gen.Permuter // lazily built by Permute()
+}
+
+// Sem lazily parses and resolves the intrinsic definition file, returning the semantic info.
+func (i *intrinsicCache) Sem() (*sem.Sem, error) {
+ if i.cachedSem == nil {
+ // Load the intrinsic definition file
+ defPath := filepath.Join(fileutils.DawnRoot(), i.path)
+ defSource, err := os.ReadFile(defPath)
+ if err != nil {
+ return nil, err
+ }
+ // Parse the definition file to produce an AST
+ ast, err := parser.Parse(string(defSource), i.path)
+ if err != nil {
+ return nil, err
+ }
+ // Resolve the AST to produce the semantic info
+ resolved, err := resolver.Resolve(ast)
+ if err != nil {
+ return nil, err
+ }
+
+ i.cachedSem = resolved
+ }
+ return i.cachedSem, nil
+}
+
+// Table returns the table built by gen.BuildIntrinsicTable() from Sem(), building
+// it on first use. A failed build is not cached, so a later call retries it.
+func (i *intrinsicCache) Table() (*gen.IntrinsicTable, error) {
+ if i.cachedTable == nil {
+ s, err := i.Sem()
+ if err != nil {
+ return nil, err
+ }
+ table, err := gen.BuildIntrinsicTable(s)
+ if err != nil {
+ return nil, err
+ }
+ i.cachedTable = table
+ }
+ return i.cachedTable, nil
+}
+
+// Permute returns the permutations of overload, using a gen.Permuter created
+// from Sem() on first use and then cached (a failed creation is not cached).
+func (i *intrinsicCache) Permute(overload *sem.Overload) ([]gen.Permutation, error) {
+ if i.cachedPermuter == nil {
+ s, err := i.Sem()
+ if err != nil {
+ return nil, err
+ }
+ permuter, err := gen.NewPermuter(s)
+ if err != nil {
+ return nil, err
+ }
+ i.cachedPermuter = permuter
+ }
+ return i.cachedPermuter.Permute(overload)
+}
+
+// genCache holds objects that are expensive to build and reusable across templates.
+type genCache struct {
+ intrinsicsCache container.Map[string, *intrinsicCache] // keyed by definition file path
+}
+
+// intrinsics returns the intrinsicCache for path, creating it on first use.
+func (g *genCache) intrinsics(path string) *intrinsicCache {
+ if g.intrinsicsCache == nil {
+ g.intrinsicsCache = container.NewMap[string, *intrinsicCache]()
+ }
+ if cached := g.intrinsicsCache[path]; cached != nil {
+ return cached
+ }
+ g.intrinsicsCache[path] = &intrinsicCache{path: path}
+ return g.intrinsicsCache[path]
+}
+
+// WriteFile is the function called by the "WriteFile" template helper to emit
+// an additional file. relPath is relative to the directory of the currently
+// executing template, and content is the file content to write.
+type WriteFile func(relPath, content string) error
+
+// generator holds the state used by the helper functions exposed to templates.
+type generator struct {
+ cache *genCache // shared by every template processed in one Run
+ writeFile WriteFile
+ rnd *rand.Rand // source of randomness for Scramble
+}
+
+// generate loads the template at tmplPath and executes it, writing the output
+// to w. See https://golang.org/pkg/text/template/ for the template syntax.
+// Files emitted via the template's WriteFile helper are passed to writeFile.
+func generate(tmplPath string, cache *genCache, w io.Writer, writeFile WriteFile) error {
+ g := generator{
+ cache: cache,
+ writeFile: writeFile,
+ rnd: rand.New(rand.NewSource(4561123)), // fixed seed: each template's Scramble output is deterministic
+ }
+
+ funcs := map[string]any{
+ "SplitDisplayName": gen.SplitDisplayName,
+ "Scramble": g.scramble,
+ "IsEnumEntry": is(sem.EnumEntry{}),
+ "IsEnumMatcher": is(sem.EnumMatcher{}),
+ "IsFQN": is(sem.FullyQualifiedName{}),
+ "IsInt": is(1),
+ "IsTemplateEnumParam": is(sem.TemplateEnumParam{}),
+ "IsTemplateNumberParam": is(sem.TemplateNumberParam{}),
+ "IsTemplateTypeParam": is(sem.TemplateTypeParam{}),
+ "IsType": is(sem.Type{}),
+ "ElementType": gen.ElementType,
+ "DeepestElementType": gen.DeepestElementType,
+ "IsAbstract": gen.IsAbstract,
+ "IsDeclarable": gen.IsDeclarable,
+ "IsHostShareable": gen.IsHostShareable,
+ "OverloadUsesF16": gen.OverloadUsesF16,
+ "OverloadUsesReadWriteStorageTexture": gen.OverloadUsesReadWriteStorageTexture,
+ "IsFirstIn": isFirstIn,
+ "IsLastIn": isLastIn,
+ "LoadIntrinsics": func(path string) *intrinsicCache { return g.cache.intrinsics(path) },
+ "WriteFile": func(relPath, content string) (string, error) { return "", g.writeFile(relPath, content) },
+ }
+ t, err := template.FromFile(tmplPath)
+ if err != nil {
+ return err
+ }
+ return t.Run(w, nil, funcs)
+}
+
+// scramble returns a randomly modified copy of str that is not contained in
+// avoid. It applies 1 to 5 random single-byte edits (replace, delete or
+// insert), and keeps editing until the result is not in avoid.
+func (g *generator) scramble(str string, avoid container.Set[string]) (string, error) {
+ bytes := []byte(str)
+ passes := g.rnd.Intn(5) + 1
+ const chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz"
+ char := func() byte { return chars[g.rnd.Intn(len(chars))] }
+ replace := func(at int) { bytes[at] = char() }
+ remove := func(at int) { bytes = append(bytes[:at], bytes[at+1:]...) }
+ // Copy the tail first: appending to bytes[:at] in place would clobber bytes[at].
+ insert := func(at int) { bytes = append(bytes[:at], append([]byte{char()}, bytes[at:]...)...) }
+
+ for i := 0; i < passes || avoid.Contains(string(bytes)); i++ {
+ if len(bytes) > 0 {
+ at := g.rnd.Intn(len(bytes))
+ switch g.rnd.Intn(3) {
+ case 0:
+ replace(at)
+ case 1:
+ remove(at)
+ case 2:
+ insert(at)
+ }
+ } else {
+ insert(0)
+ }
+ }
+ return string(bytes), nil
+}
+
+// is returns a predicate reporting whether its argument's dynamic type is the
+// type of ty, or a pointer to that type.
+func is(ty any) func(any) bool {
+ want := reflect.TypeOf(ty)
+ return func(v any) bool {
+ got := reflect.TypeOf(v)
+ return got == want || got == reflect.PointerTo(want)
+ }
+}
+
+// isFirstIn returns true if v is the first element of slice. slice must be a
+// value whose reflect.Value supports Len and Index (e.g. a slice or array).
+func isFirstIn(v, slice any) bool {
+ s := reflect.ValueOf(slice)
+ if s.Len() == 0 {
+ return false
+ }
+ return s.Index(0).Interface() == v
+}
+
+// isLastIn returns true if v is the last element of slice. slice must be a
+// value whose reflect.Value supports Len and Index (e.g. a slice or array).
+func isLastIn(v, slice any) bool {
+ s := reflect.ValueOf(slice)
+ if n := s.Len(); n > 0 {
+ return s.Index(n-1).Interface() == v
+ }
+ return false
+}
+
+// clangFormat formats in using the clang-format executable at exe, returning its stdout.
+func clangFormat(in, exe string) (string, error) {
+ stderr := strings.Builder{}
+ cmd := exec.Command(exe)
+ cmd.Stdin, cmd.Stderr = strings.NewReader(in), &stderr // keep diagnostics out of the result
+ out, err := cmd.Output()
+ if err != nil {
+ return "", fmt.Errorf("clang-format failed:\n%v\n%v", stderr.String(), err)
+ }
+ return string(out), nil
+}
+
+// findClangFormat returns the clang-format executable found under
+// <projectRoot>/buildtools for the host OS, falling back to the first
+// clang-format on PATH. It returns "" if neither is available.
+func findClangFormat(projectRoot string) string {
+ bundled := map[string]string{
+ "linux": "buildtools/linux64/clang-format",
+ "darwin": "buildtools/mac/clang-format",
+ "windows": "buildtools/win/clang-format.exe",
+ }
+ var path string
+ if rel, ok := bundled[runtime.GOOS]; ok {
+ path = filepath.Join(projectRoot, rel)
+ }
+ if fileutils.IsExe(path) {
+ return path
+ }
+ if onPath, err := exec.LookPath("clang-format"); err == nil {
+ return onPath
+ }
+ return ""
+}
diff --git a/tools/src/cmd/tint-bench/main.go b/tools/src/cmd/tint-bench/main.go
index 6a01616..9d9eb40 100644
--- a/tools/src/cmd/tint-bench/main.go
+++ b/tools/src/cmd/tint-bench/main.go
@@ -19,7 +19,6 @@
import (
"flag"
"fmt"
- "io/ioutil"
"os"
"os/exec"
"path/filepath"
@@ -71,12 +70,12 @@
return fmt.Errorf("missing template path")
}
- tmpl, err := ioutil.ReadFile(tmplPath)
+ tmpl, err := template.FromFile(tmplPath)
if err != nil {
if !filepath.IsAbs(tmplPath) {
// Try relative to this .go file
tmplPath = filepath.Join(fileutils.ThisDir(), tmplPath)
- tmpl, err = ioutil.ReadFile(tmplPath)
+ tmpl, err = template.FromFile(tmplPath)
}
}
if err != nil {
@@ -97,7 +96,7 @@
funcs := template.Functions{
"Alpha": func() int { return alpha },
}
- wgslPath, err := writeWGSLFile(string(tmpl), funcs)
+ wgslPath, err := writeWGSLFile(tmpl, funcs)
if err != nil {
return err
}
@@ -123,14 +122,14 @@
return nil
}
-func writeWGSLFile(tmpl string, funcs template.Functions) (string, error) {
+func writeWGSLFile(tmpl *template.Template, funcs template.Functions) (string, error) {
const path = "tint-bench.wgsl"
wgslFile, err := os.Create(path)
if err != nil {
return "", fmt.Errorf("failed to create benchmark WGSL test file: %w", err)
}
defer wgslFile.Close()
- if err := template.Run(tmpl, wgslFile, funcs); err != nil {
+ if err := tmpl.Run(wgslFile, nil, funcs); err != nil {
return "", fmt.Errorf("template error:\n%w", err)
}
return path, nil
diff --git a/tools/src/subcmd/subcmd.go b/tools/src/subcmd/subcmd.go
index 88f1158..a5db01e 100644
--- a/tools/src/subcmd/subcmd.go
+++ b/tools/src/subcmd/subcmd.go
@@ -120,7 +120,7 @@
}
if profile {
fmt.Println("download profile at: localhost:8080/profile")
- fmt.Println("then run: 'go tool pprof <file>")
+ fmt.Println("then run: 'go tool pprof <file>'")
go http.ListenAndServe(":8080", mux)
}
return cmd.Run(ctx, data)
diff --git a/tools/src/template/template.go b/tools/src/template/template.go
index 36c8bae..0bd2180 100644
--- a/tools/src/template/template.go
+++ b/tools/src/template/template.go
@@ -20,7 +20,9 @@
"fmt"
"io"
"io/ioutil"
+ "os"
"path/filepath"
+ "reflect"
"strings"
"text/template"
"unicode"
@@ -29,15 +31,35 @@
)
// The template function binding table
-type Functions map[string]interface{}
+type Functions = template.FuncMap
+
+// Template is the source text of a text/template together with the name it is
+// parsed under. The name appears in the template's parse and execution errors;
+// for templates created with FromFile, it is the file path.
+type Template struct {
+	name    string // template name used in error messages
+	content string // template source text
+}
+
+// FromFile reads the file at path and returns a Template holding its content.
+// The path is used as the template's name so that template errors identify the
+// file. An error from reading the file is returned unchanged.
+func FromFile(path string) (*Template, error) {
+	src, readErr := os.ReadFile(path)
+	if readErr == nil {
+		return FromString(path, string(src)), nil
+	}
+	return nil, readErr
+}
+
+// FromString returns a new Template with the given name whose source text is
+// content. The content is not parsed until the template is run.
+func FromString(name, content string) *Template {
+	tmpl := new(Template)
+	tmpl.name = name
+	tmpl.content = content
+	return tmpl
+}
// Run executes the template tmpl, writing the output to w.
// funcs are the functions provided to the template.
// See https://golang.org/pkg/text/template/ for documentation on the template
// syntax.
-func Run(tmpl string, w io.Writer, funcs Functions) error {
+func (t *Template) Run(w io.Writer, data any, funcs Functions) error {
g := generator{
- template: template.New("<template>"),
+ template: template.New(t.name),
}
globals := newMap()
@@ -53,12 +75,18 @@
"Iterate": iterate,
"Map": newMap,
"PascalCase": pascalCase,
+ "ToUpper": strings.ToUpper,
+ "ToLower": strings.ToLower,
+ "Repeat": strings.Repeat,
"Split": strings.Split,
"Title": strings.Title,
"TrimLeft": strings.TrimLeft,
"TrimPrefix": strings.TrimPrefix,
"TrimRight": strings.TrimRight,
"TrimSuffix": strings.TrimSuffix,
+ "Replace": strings.ReplaceAll,
+ "Index": index,
+ "Error": func(err any) string { panic(err) },
}
// Append custom functions
@@ -66,11 +94,11 @@
g.funcs[name] = fn
}
- if err := g.bindAndParse(g.template, tmpl); err != nil {
+ if err := g.bindAndParse(g.template, t.content); err != nil {
return err
}
- return g.template.Execute(w, nil)
+ return g.template.Execute(w, data)
}
type generator struct {
@@ -195,3 +223,27 @@
}
return b.String()
}
+
+// index returns the value reached by successively indexing obj with each of
+// indices, i.e. obj[indices[0]][indices[1]]...
+//
+// Interfaces and pointers are dereferenced before each indexing step. Arrays
+// and slices accept an index of any integer type; maps accept a key assignable
+// to the map's key type.
+//
+// index returns (nil, nil) if a value along the way, or the final value, is
+// invalid or the zero value of its type, or if a map key is missing. This
+// means zero values such as 0, "" and false are also returned as nil.
+//
+// An error is returned if a value is not indexable, an index has the wrong
+// type, or an array or slice index is out of range.
+func index(obj any, indices ...any) (any, error) {
+	v := reflect.ValueOf(obj)
+	for _, idx := range indices {
+		for v.Kind() == reflect.Interface || v.Kind() == reflect.Pointer {
+			v = v.Elem()
+		}
+		// IsZero also covers nil pointers, maps, slices and interfaces. Unlike
+		// IsNil, it does not panic for non-nillable kinds (ints, structs, ...).
+		if !v.IsValid() || v.IsZero() {
+			return nil, nil
+		}
+		switch v.Kind() {
+		case reflect.Array, reflect.Slice:
+			i := reflect.ValueOf(idx)
+			n := v.Len()
+			switch {
+			case i.CanInt() && i.Int() >= 0 && i.Int() < int64(n):
+				v = v.Index(int(i.Int()))
+			case i.CanUint() && i.Uint() < uint64(n):
+				v = v.Index(int(i.Uint()))
+			case i.CanInt() || i.CanUint():
+				return nil, fmt.Errorf("index %v out of range for %v of length %d", idx, v.Type(), n)
+			default:
+				return nil, fmt.Errorf("cannot index %v with %T", v.Type(), idx)
+			}
+		case reflect.Map:
+			key := reflect.ValueOf(idx)
+			// MapIndex panics on keys that are not assignable to the key type.
+			if !key.IsValid() || !key.Type().AssignableTo(v.Type().Key()) {
+				return nil, fmt.Errorf("cannot index %v with %T", v.Type(), idx)
+			}
+			v = v.MapIndex(key) // invalid Value if the key is missing
+		default:
+			return nil, fmt.Errorf("cannot index %v (%v)", v.Type(), v.Kind())
+		}
+	}
+	if !v.IsValid() || v.IsZero() {
+		return nil, nil
+	}
+	return v.Interface(), nil
+}
diff --git a/tools/src/template/template_test.go b/tools/src/template/template_test.go
index f18e727..d4f94c4 100644
--- a/tools/src/template/template_test.go
+++ b/tools/src/template/template_test.go
@@ -22,12 +22,11 @@
"github.com/google/go-cmp/cmp"
)
-func check(t *testing.T, tmpl, expected string, fns template.Functions) {
+func check(t *testing.T, content, expected string, fns template.Functions) {
t.Helper()
w := &bytes.Buffer{}
- err := template.Run(tmpl, w, fns)
- if err != nil {
- t.Errorf("template.Run() failed with %v", err)
+ if err := template.FromString("template", content).Run(w, nil, fns); err != nil {
+ t.Errorf("Template.Run() failed with %v", err)
return
}
got := w.String()