| // Copyright 2021 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| // test-runner runs tint against a number of test shaders checking for expected behavior |
| package main |
| |
| import ( |
| "context" |
| "flag" |
| "fmt" |
| "io/ioutil" |
| "os" |
| "os/exec" |
| "path/filepath" |
| "runtime" |
| "sort" |
| "strings" |
| "time" |
| "unicode/utf8" |
| |
| "dawn.googlesource.com/tint/tools/src/fileutils" |
| "dawn.googlesource.com/tint/tools/src/glob" |
| "github.com/fatih/color" |
| "github.com/sergi/go-diff/diffmatchpatch" |
| ) |
| |
// outputFormat names one of the output formats that tint is asked to emit via
// its --format flag.
type outputFormat string

// testTimeout is the longest a single tint invocation may run before it is
// cancelled.
const testTimeout = 30 * time.Second

// The output formats understood by the test runner.
const (
	wgsl   outputFormat = "wgsl"
	spvasm outputFormat = "spvasm"
	msl    outputFormat = "msl"
	hlsl   outputFormat = "hlsl"
)
| |
| func main() { |
| if err := run(); err != nil { |
| fmt.Println(err) |
| os.Exit(1) |
| } |
| } |
| |
// showUsage prints the command line help, followed by the defaults of every
// registered flag, and then terminates the process with exit status 1.
func showUsage() {
	const usage = `
test-runner runs tint against a number of test shaders checking for expected behavior

usage:
test-runner [flags...] <executable> [<directory>]

<executable> the path to the tint executable
<directory> the root directory of the test files

optional flags:`

	fmt.Println(usage)
	flag.PrintDefaults()
	fmt.Println("")
	os.Exit(1)
}
| |
| func run() error { |
| var formatList, filter, dxcPath, xcrunPath string |
| var maxFilenameColumnWidth int |
| numCPU := runtime.NumCPU() |
| fxc, verbose, generateExpected, generateSkip := false, false, false, false |
| flag.StringVar(&formatList, "format", "all", "comma separated list of formats to emit. Possible values are: all, wgsl, spvasm, msl, hlsl") |
| flag.StringVar(&filter, "filter", "**.wgsl, **.spvasm, **.spv", "comma separated list of glob patterns for test files") |
| flag.StringVar(&dxcPath, "dxc", "", "path to DXC executable for validating HLSL output") |
| flag.StringVar(&xcrunPath, "xcrun", "", "path to xcrun executable for validating MSL output") |
| flag.BoolVar(&fxc, "fxc", false, "validate with FXC instead of DXC") |
| flag.BoolVar(&verbose, "verbose", false, "print all run tests, including rows that all pass") |
| flag.BoolVar(&generateExpected, "generate-expected", false, "create or update all expected outputs") |
| flag.BoolVar(&generateSkip, "generate-skip", false, "create or update all expected outputs that fail with SKIP") |
| flag.IntVar(&numCPU, "j", numCPU, "maximum number of concurrent threads to run tests") |
| flag.IntVar(&maxFilenameColumnWidth, "filename-column-width", 0, "maximum width of the filename column") |
| flag.Usage = showUsage |
| flag.Parse() |
| |
| args := flag.Args() |
| if len(args) == 0 { |
| showUsage() |
| } |
| |
| // executable path is the first argument |
| exe, args := args[0], args[1:] |
| |
| // (optional) target directory is the second argument |
| dir := "." |
| if len(args) > 0 { |
| dir, args = args[0], args[1:] |
| } |
| |
| // Check the executable can be found and actually is executable |
| if !fileutils.IsExe(exe) { |
| return fmt.Errorf("'%s' not found or is not executable", exe) |
| } |
| exe, err := filepath.Abs(exe) |
| if err != nil { |
| return err |
| } |
| |
| // Allow using '/' in the filter on Windows |
| filter = strings.ReplaceAll(filter, "/", string(filepath.Separator)) |
| |
| // Split the --filter flag up by ',', trimming any whitespace at the start and end |
| globIncludes := strings.Split(filter, ",") |
| for i, s := range globIncludes { |
| // Escape backslashes for the glob config |
| s = strings.ReplaceAll(s, `\`, `\\`) |
| globIncludes[i] = `"` + strings.TrimSpace(s) + `"` |
| } |
| |
| // Glob the files to test |
| files, err := glob.Scan(dir, glob.MustParseConfig(`{ |
| "paths": [ |
| { |
| "include": [ `+strings.Join(globIncludes, ",")+` ] |
| }, |
| { |
| "exclude": [ |
| "**.expected.wgsl", |
| "**.expected.spvasm", |
| "**.expected.msl", |
| "**.expected.hlsl" |
| ] |
| } |
| ] |
| }`)) |
| if err != nil { |
| return fmt.Errorf("Failed to glob files: %w", err) |
| } |
| |
| // Ensure the files are sorted (globbing should do this, but why not) |
| sort.Strings(files) |
| |
| // Parse --format into a list of outputFormat |
| formats := []outputFormat{} |
| if formatList == "all" { |
| formats = []outputFormat{wgsl, spvasm, msl, hlsl} |
| } else { |
| for _, f := range strings.Split(formatList, ",") { |
| switch strings.TrimSpace(f) { |
| case "wgsl": |
| formats = append(formats, wgsl) |
| case "spvasm": |
| formats = append(formats, spvasm) |
| case "msl": |
| formats = append(formats, msl) |
| case "hlsl": |
| formats = append(formats, hlsl) |
| default: |
| return fmt.Errorf("unknown format '%s'", f) |
| } |
| } |
| } |
| |
| defaultMSLExe := "xcrun" |
| if runtime.GOOS == "windows" { |
| defaultMSLExe = "metal.exe" |
| } |
| |
| // If explicit verification compilers have been specified, check they exist. |
| // Otherwise, look on PATH for them, but don't error if they cannot be found. |
| for _, tool := range []struct { |
| name string |
| lang string |
| path *string |
| }{ |
| {"dxc", "hlsl", &dxcPath}, |
| {defaultMSLExe, "msl", &xcrunPath}, |
| } { |
| if *tool.path == "" { |
| p, err := exec.LookPath(tool.name) |
| if err == nil && fileutils.IsExe(p) { |
| *tool.path = p |
| } |
| } else if !fileutils.IsExe(*tool.path) { |
| return fmt.Errorf("%v not found at '%v'", tool.name, *tool.path) |
| } |
| |
| color.Set(color.FgCyan) |
| fmt.Printf("%-4s", tool.lang) |
| color.Unset() |
| fmt.Printf(" validation ") |
| if *tool.path != "" || (fxc && tool.lang == "hlsl") { |
| color.Set(color.FgGreen) |
| fmt.Printf("ENABLED") |
| } else { |
| color.Set(color.FgRed) |
| fmt.Printf("DISABLED") |
| } |
| color.Unset() |
| fmt.Println() |
| } |
| fmt.Println() |
| |
| // Build the list of results. |
| // These hold the chans used to report the job results. |
| results := make([]map[outputFormat]chan status, len(files)) |
| for i := range files { |
| fileResults := map[outputFormat]chan status{} |
| for _, format := range formats { |
| fileResults[format] = make(chan status, 1) |
| } |
| results[i] = fileResults |
| } |
| |
| pendingJobs := make(chan job, 256) |
| |
| // Spawn numCPU job runners... |
| for cpu := 0; cpu < numCPU; cpu++ { |
| go func() { |
| for job := range pendingJobs { |
| job.run(dir, exe, fxc, dxcPath, xcrunPath, generateExpected, generateSkip) |
| } |
| }() |
| } |
| |
| // Issue the jobs... |
| go func() { |
| for i, file := range files { // For each test file... |
| file := filepath.Join(dir, file) |
| for _, format := range formats { // For each output format... |
| pendingJobs <- job{ |
| file: file, |
| format: format, |
| result: results[i][format], |
| } |
| } |
| } |
| close(pendingJobs) |
| }() |
| |
| type failure struct { |
| file string |
| format outputFormat |
| err error |
| } |
| |
| type stats struct { |
| numTests, numPass, numSkip, numFail int |
| } |
| |
| // Statistics per output format |
| statsByFmt := map[outputFormat]*stats{} |
| for _, format := range formats { |
| statsByFmt[format] = &stats{} |
| |
| } |
| |
| // Print the table of file x format and gather per-format stats |
| failures := []failure{} |
| filenameColumnWidth := maxStringLen(files) |
| if maxFilenameColumnWidth > 0 { |
| filenameColumnWidth = maxFilenameColumnWidth |
| } |
| |
| red := color.New(color.FgRed) |
| green := color.New(color.FgGreen) |
| yellow := color.New(color.FgYellow) |
| cyan := color.New(color.FgCyan) |
| |
| printFormatsHeader := func() { |
| fmt.Printf(strings.Repeat(" ", filenameColumnWidth)) |
| fmt.Printf(" ┃ ") |
| for _, format := range formats { |
| cyan.Printf(alignCenter(format, formatWidth(format))) |
| fmt.Printf(" │ ") |
| } |
| fmt.Println() |
| } |
| printHorizontalLine := func() { |
| fmt.Printf(strings.Repeat("━", filenameColumnWidth)) |
| fmt.Printf("━╋━") |
| for _, format := range formats { |
| fmt.Printf(strings.Repeat("━", formatWidth(format))) |
| fmt.Printf("━┿━") |
| } |
| fmt.Println() |
| } |
| |
| fmt.Println() |
| |
| printFormatsHeader() |
| printHorizontalLine() |
| |
| for i, file := range files { |
| results := results[i] |
| |
| row := &strings.Builder{} |
| rowAllPassed := true |
| |
| filenameLength := utf8.RuneCountInString(file) |
| shortFile := file |
| if filenameLength > filenameColumnWidth { |
| shortFile = "..." + file[filenameLength-filenameColumnWidth+3:] |
| } |
| |
| fmt.Fprintf(row, alignRight(shortFile, filenameColumnWidth)) |
| fmt.Fprintf(row, " ┃ ") |
| for _, format := range formats { |
| columnWidth := formatWidth(format) |
| result := <-results[format] |
| stats := statsByFmt[format] |
| stats.numTests++ |
| if err := result.err; err != nil { |
| failures = append(failures, failure{ |
| file: file, format: format, err: err, |
| }) |
| } |
| switch result.code { |
| case pass: |
| green.Fprintf(row, alignCenter("PASS", columnWidth)) |
| stats.numPass++ |
| case fail: |
| red.Fprintf(row, alignCenter("FAIL", columnWidth)) |
| rowAllPassed = false |
| stats.numFail++ |
| case skip: |
| yellow.Fprintf(row, alignCenter("SKIP", columnWidth)) |
| rowAllPassed = false |
| stats.numSkip++ |
| default: |
| fmt.Fprintf(row, alignCenter(result.code, columnWidth)) |
| rowAllPassed = false |
| } |
| fmt.Fprintf(row, " │ ") |
| } |
| |
| if verbose || !rowAllPassed { |
| fmt.Fprintln(color.Output, row) |
| } |
| } |
| |
| printHorizontalLine() |
| printFormatsHeader() |
| printHorizontalLine() |
| printStat := func(col *color.Color, name string, num func(*stats) int) { |
| row := &strings.Builder{} |
| anyNonZero := false |
| for _, format := range formats { |
| columnWidth := formatWidth(format) |
| count := num(statsByFmt[format]) |
| if count > 0 { |
| col.Fprintf(row, alignLeft(count, columnWidth)) |
| anyNonZero = true |
| } else { |
| fmt.Fprintf(row, alignLeft(count, columnWidth)) |
| } |
| fmt.Fprintf(row, " │ ") |
| } |
| |
| if !anyNonZero { |
| return |
| } |
| col.Printf(alignRight(name, filenameColumnWidth)) |
| fmt.Printf(" ┃ ") |
| fmt.Fprintln(color.Output, row) |
| |
| col.Printf(strings.Repeat(" ", filenameColumnWidth)) |
| fmt.Printf(" ┃ ") |
| for _, format := range formats { |
| columnWidth := formatWidth(format) |
| stats := statsByFmt[format] |
| count := num(stats) |
| percent := percentage(count, stats.numTests) |
| if count > 0 { |
| col.Print(alignRight(percent, columnWidth)) |
| } else { |
| fmt.Print(alignRight(percent, columnWidth)) |
| } |
| fmt.Printf(" │ ") |
| } |
| fmt.Println() |
| } |
| printStat(green, "PASS", func(s *stats) int { return s.numPass }) |
| printStat(yellow, "SKIP", func(s *stats) int { return s.numSkip }) |
| printStat(red, "FAIL", func(s *stats) int { return s.numFail }) |
| fmt.Println() |
| |
| for _, f := range failures { |
| color.Set(color.FgBlue) |
| fmt.Printf("%s ", f.file) |
| color.Set(color.FgCyan) |
| fmt.Printf("%s ", f.format) |
| color.Set(color.FgRed) |
| fmt.Println("FAIL") |
| color.Unset() |
| fmt.Println(indent(f.err.Error(), 4)) |
| } |
| if len(failures) > 0 { |
| fmt.Println() |
| } |
| |
| allStats := stats{} |
| for _, format := range formats { |
| stats := statsByFmt[format] |
| allStats.numTests += stats.numTests |
| allStats.numPass += stats.numPass |
| allStats.numSkip += stats.numSkip |
| allStats.numFail += stats.numFail |
| } |
| |
| fmt.Printf("%d tests run", allStats.numTests) |
| if allStats.numPass > 0 { |
| fmt.Printf(", ") |
| color.Set(color.FgGreen) |
| fmt.Printf("%d tests pass", allStats.numPass) |
| color.Unset() |
| } else { |
| fmt.Printf(", %d tests pass", allStats.numPass) |
| } |
| if allStats.numSkip > 0 { |
| fmt.Printf(", ") |
| color.Set(color.FgYellow) |
| fmt.Printf("%d tests skipped", allStats.numSkip) |
| color.Unset() |
| } else { |
| fmt.Printf(", %d tests skipped", allStats.numSkip) |
| } |
| if allStats.numFail > 0 { |
| fmt.Printf(", ") |
| color.Set(color.FgRed) |
| fmt.Printf("%d tests failed", allStats.numFail) |
| color.Unset() |
| } else { |
| fmt.Printf(", %d tests failed", allStats.numFail) |
| } |
| fmt.Println() |
| fmt.Println() |
| |
| if allStats.numFail > 0 { |
| os.Exit(1) |
| } |
| |
| return nil |
| } |
| |
// Types and values used to report the results of the tests.

// statusCode is the outcome of a single test job. Its value is the text
// shown for that outcome.
type statusCode string

// The outcomes a test job can report.
const (
	fail = statusCode("FAIL")
	pass = statusCode("PASS")
	skip = statusCode("SKIP")
)
| |
| type status struct { |
| code statusCode |
| err error |
| } |
| |
| type job struct { |
| file string |
| format outputFormat |
| result chan status |
| } |
| |
| func (j job) run(wd, exe string, fxc bool, dxcPath, xcrunPath string, generateExpected, generateSkip bool) { |
| j.result <- func() status { |
| // Is there an expected output? |
| expected := loadExpectedFile(j.file, j.format) |
| skipped := false |
| if strings.HasPrefix(expected, "SKIP") { // Special SKIP token |
| skipped = true |
| } |
| |
| expected = strings.ReplaceAll(expected, "\r\n", "\n") |
| |
| file, err := filepath.Rel(wd, j.file) |
| if err != nil { |
| file = j.file |
| } |
| |
| // Make relative paths use forward slash separators (on Windows) so that paths in tint |
| // output match expected output that contain errors |
| file = strings.ReplaceAll(file, `\`, `/`) |
| |
| args := []string{ |
| file, |
| "--format", string(j.format), |
| } |
| |
| // Can we validate? |
| validate := false |
| switch j.format { |
| case wgsl: |
| validate = true |
| case spvasm: |
| args = append(args, "--validate") // spirv-val is statically linked, always available |
| validate = true |
| case hlsl: |
| if fxc { |
| args = append(args, "--fxc") |
| validate = true |
| } else if dxcPath != "" { |
| args = append(args, "--dxc", dxcPath) |
| validate = true |
| } |
| case msl: |
| if xcrunPath != "" { |
| args = append(args, "--xcrun", xcrunPath) |
| validate = true |
| } |
| } |
| |
| // Invoke the compiler... |
| ok, out := invoke(wd, exe, args...) |
| out = strings.ReplaceAll(out, "\r\n", "\n") |
| matched := expected == "" || expected == out |
| |
| if ok && generateExpected && (validate || !skipped) { |
| saveExpectedFile(j.file, j.format, out) |
| matched = true |
| } |
| |
| switch { |
| case ok && matched: |
| // Test passed |
| return status{code: pass} |
| |
| // --- Below this point the test has failed --- |
| |
| case skipped: |
| if generateSkip { |
| saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out) |
| } |
| return status{code: skip} |
| |
| case !ok: |
| // Compiler returned non-zero exit code |
| if generateSkip { |
| saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out) |
| } |
| err := fmt.Errorf("%s", out) |
| return status{code: fail, err: err} |
| |
| default: |
| // Compiler returned zero exit code, or output was not as expected |
| if generateSkip { |
| saveExpectedFile(j.file, j.format, "SKIP: FAILED\n\n"+out) |
| } |
| |
| // Expected output did not match |
| dmp := diffmatchpatch.New() |
| diff := dmp.DiffPrettyText(dmp.DiffMain(expected, out, true)) |
| err := fmt.Errorf(`Output was not as expected |
| |
| -------------------------------------------------------------------------------- |
| -- Expected: -- |
| -------------------------------------------------------------------------------- |
| %s |
| |
| -------------------------------------------------------------------------------- |
| -- Got: -- |
| -------------------------------------------------------------------------------- |
| %s |
| |
| -------------------------------------------------------------------------------- |
| -- Diff: -- |
| -------------------------------------------------------------------------------- |
| %s`, |
| expected, out, diff) |
| return status{code: fail, err: err} |
| } |
| }() |
| } |
| |
| // loadExpectedFile loads the expected output file for the test file at 'path' |
| // and the output format 'format'. If the file does not exist, or cannot be |
| // read, then an empty string is returned. |
| func loadExpectedFile(path string, format outputFormat) string { |
| content, err := ioutil.ReadFile(expectedFilePath(path, format)) |
| if err != nil { |
| return "" |
| } |
| return string(content) |
| } |
| |
| // saveExpectedFile writes the expected output file for the test file at 'path' |
| // and the output format 'format', with the content 'content'. |
| func saveExpectedFile(path string, format outputFormat, content string) error { |
| return ioutil.WriteFile(expectedFilePath(path, format), []byte(content), 0666) |
| } |
| |
| // expectedFilePath returns the expected output file path for the test file at |
| // 'path' and the output format 'format'. |
| func expectedFilePath(path string, format outputFormat) string { |
| return path + ".expected." + string(format) |
| } |
| |
// indent returns 's' with every line, including the first, prefixed with 'n'
// space characters.
func indent(s string, n int) string {
	prefix := strings.Repeat(" ", n)
	lines := strings.Split(s, "\n")
	for i := range lines {
		lines[i] = prefix + lines[i]
	}
	return strings.Join(lines, "\n")
}
| |
// alignLeft returns the string of 'val' padded with trailing spaces so that it
// is aligned left in a column of the given width. The width is measured in
// runes. If the string is already at least as wide as the column, then it is
// returned without padding.
func alignLeft(val interface{}, width int) string {
	s := fmt.Sprint(val)
	padding := width - utf8.RuneCountInString(s)
	if padding <= 0 {
		// strings.Repeat panics on a negative count
		return s
	}
	return s + strings.Repeat(" ", padding)
}
| |
// alignCenter returns the string of 'val' padded with spaces so that it is
// centered in a column of the given width. The width is measured in runes.
// When the padding cannot be split evenly, the extra space goes on the right.
// If the string is already at least as wide as the column, then it is returned
// without padding.
func alignCenter(val interface{}, width int) string {
	s := fmt.Sprint(val)
	padding := width - utf8.RuneCountInString(s)
	if padding <= 0 {
		// strings.Repeat panics on a negative count
		return s
	}
	return strings.Repeat(" ", padding/2) + s + strings.Repeat(" ", (padding+1)/2)
}
| |
// alignRight returns the string of 'val' padded with leading spaces so that it
// is aligned right in a column of the given width. The width is measured in
// runes. If the string is already at least as wide as the column, then it is
// returned without padding.
func alignRight(val interface{}, width int) string {
	s := fmt.Sprint(val)
	padding := width - utf8.RuneCountInString(s)
	if padding <= 0 {
		// strings.Repeat panics on a negative count
		return s
	}
	return strings.Repeat(" ", padding) + s
}
| |
// maxStringLen returns the length, in runes, of the longest string in 'l'.
// Zero is returned if 'l' is empty.
func maxStringLen(l []string) int {
	longest := 0
	for i := range l {
		n := utf8.RuneCountInString(l[i])
		if n <= longest {
			continue
		}
		longest = n
	}
	return longest
}
| |
| // formatWidth returns the width in runes for the outputFormat column 'b' |
| func formatWidth(b outputFormat) int { |
| const min = 6 |
| c := utf8.RuneCountInString(string(b)) |
| if c < min { |
| return min |
| } |
| return c |
| } |
| |
// percentage formats 'n' as a percentage of 'total', with one decimal place
// and a trailing '%'. If 'total' is zero, then "-" is returned instead.
func percentage(n, total int) string {
	if total != 0 {
		ratio := float64(n) / float64(total)
		return fmt.Sprintf("%.1f%%", ratio*100.0)
	}
	return "-"
}
| |
| // invoke runs the executable 'exe' with the provided arguments. |
| func invoke(wd, exe string, args ...string) (ok bool, output string) { |
| ctx, cancel := context.WithTimeout(context.Background(), testTimeout) |
| defer cancel() |
| |
| cmd := exec.CommandContext(ctx, exe, args...) |
| cmd.Dir = wd |
| out, err := cmd.CombinedOutput() |
| str := string(out) |
| if err != nil { |
| if ctx.Err() == context.DeadlineExceeded { |
| return false, fmt.Sprintf("test timed out after %v", testTimeout) |
| } |
| if str != "" { |
| return false, str |
| } |
| return false, err.Error() |
| } |
| return true, str |
| } |