blob: 32be5cff7cda971013303152466630d588ac2ec8 [file] [log] [blame]
Austin Enge58d5a32021-01-27 22:54:04 +00001// Copyright 2021 The Dawn Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "utils/WireHelper.h"
16
17#include "common/Assert.h"
18#include "common/Log.h"
19#include "common/SystemUtils.h"
20#include "dawn/dawn_proc.h"
21#include "dawn_native/DawnNative.h"
22#include "dawn_wire/WireClient.h"
23#include "dawn_wire/WireServer.h"
24#include "utils/TerribleCommandBuffer.h"
25
Corentin Wallez0859e6b2021-01-29 12:50:36 +000026#include <algorithm>
Corentin Wallez42f28d62021-02-09 16:13:51 +000027#include <cstring>
Austin Enge58d5a32021-01-27 22:54:04 +000028#include <fstream>
29#include <iomanip>
30#include <set>
31#include <sstream>
32
33namespace utils {
34
35 namespace {
36
37 class WireServerTraceLayer : public dawn_wire::CommandHandler {
38 public:
39 WireServerTraceLayer(const char* dir, dawn_wire::CommandHandler* handler)
40 : dawn_wire::CommandHandler(), mDir(dir), mHandler(handler) {
41 const char* sep = GetPathSeparator();
Austin Engc00bc902021-02-03 19:25:02 +000042 if (mDir.size() > 0 && mDir.back() != *sep) {
Austin Enge58d5a32021-01-27 22:54:04 +000043 mDir += sep;
44 }
45 }
46
47 void BeginWireTrace(const char* name) {
48 std::string filename = name;
49 // Replace slashes in gtest names with underscores so everything is in one
50 // directory.
51 std::replace(filename.begin(), filename.end(), '/', '_');
52 std::replace(filename.begin(), filename.end(), '\\', '_');
53
54 // Prepend the filename with the directory.
55 filename = mDir + filename;
56
57 ASSERT(!mFile.is_open());
58 mFile.open(filename,
59 std::ios_base::out | std::ios_base::binary | std::ios_base::trunc);
Austin Engc00bc902021-02-03 19:25:02 +000060
61 // Write the initial 8 bytes. This means the fuzzer should never inject an
62 // error.
63 const uint64_t injectedErrorIndex = 0xFFFF'FFFF'FFFF'FFFF;
64 mFile.write(reinterpret_cast<const char*>(&injectedErrorIndex),
65 sizeof(injectedErrorIndex));
Austin Enge58d5a32021-01-27 22:54:04 +000066 }
67
68 const volatile char* HandleCommands(const volatile char* commands,
69 size_t size) override {
70 if (mFile.is_open()) {
71 mFile.write(const_cast<const char*>(commands), size);
72 }
73 return mHandler->HandleCommands(commands, size);
74 }
75
76 private:
77 std::string mDir;
78 dawn_wire::CommandHandler* mHandler;
79 std::ofstream mFile;
80 };
81
82 class WireHelperDirect : public WireHelper {
83 public:
84 WireHelperDirect() {
85 dawnProcSetProcs(&dawn_native::GetProcs());
86 }
87
88 std::pair<wgpu::Device, WGPUDevice> RegisterDevice(WGPUDevice backendDevice) override {
89 ASSERT(backendDevice != nullptr);
90 return std::make_pair(wgpu::Device::Acquire(backendDevice), backendDevice);
91 }
92
93 void BeginWireTrace(const char* name) override {
94 }
95
96 bool FlushClient() override {
97 return true;
98 }
99
100 bool FlushServer() override {
101 return true;
102 }
103 };
104
105 class WireHelperProxy : public WireHelper {
106 public:
107 explicit WireHelperProxy(const char* wireTraceDir) {
108 mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();
109 mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
110
111 dawn_wire::WireServerDescriptor serverDesc = {};
112 serverDesc.procs = &dawn_native::GetProcs();
113 serverDesc.serializer = mS2cBuf.get();
114
115 mWireServer.reset(new dawn_wire::WireServer(serverDesc));
116 mC2sBuf->SetHandler(mWireServer.get());
117
118 if (wireTraceDir != nullptr && strlen(wireTraceDir) > 0) {
119 mWireServerTraceLayer.reset(
120 new WireServerTraceLayer(wireTraceDir, mWireServer.get()));
121 mC2sBuf->SetHandler(mWireServerTraceLayer.get());
122 }
123
124 dawn_wire::WireClientDescriptor clientDesc = {};
125 clientDesc.serializer = mC2sBuf.get();
126
127 mWireClient.reset(new dawn_wire::WireClient(clientDesc));
128 mS2cBuf->SetHandler(mWireClient.get());
129 dawnProcSetProcs(&dawn_wire::client::GetProcs());
130 }
131
132 std::pair<wgpu::Device, WGPUDevice> RegisterDevice(WGPUDevice backendDevice) override {
133 ASSERT(backendDevice != nullptr);
134
135 auto reservation = mWireClient->ReserveDevice();
136 mWireServer->InjectDevice(backendDevice, reservation.id, reservation.generation);
137 dawn_native::GetProcs().deviceRelease(backendDevice);
138
139 return std::make_pair(wgpu::Device::Acquire(reservation.device), backendDevice);
140 }
141
142 void BeginWireTrace(const char* name) override {
143 if (mWireServerTraceLayer) {
144 return mWireServerTraceLayer->BeginWireTrace(name);
145 }
146 }
147
148 bool FlushClient() override {
149 return mC2sBuf->Flush();
150 }
151
152 bool FlushServer() override {
153 return mS2cBuf->Flush();
154 }
155
156 private:
157 std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf;
158 std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf;
159 std::unique_ptr<WireServerTraceLayer> mWireServerTraceLayer;
160 std::unique_ptr<dawn_wire::WireServer> mWireServer;
161 std::unique_ptr<dawn_wire::WireClient> mWireClient;
162 };
163
164 } // anonymous namespace
165
166 std::unique_ptr<WireHelper> CreateWireHelper(bool useWire, const char* wireTraceDir) {
167 if (useWire) {
168 return std::unique_ptr<WireHelper>(new WireHelperProxy(wireTraceDir));
169 } else {
170 return std::unique_ptr<WireHelper>(new WireHelperDirect());
171 }
172 }
173
174 WireHelper::~WireHelper() {
175 dawnProcSetProcs(nullptr);
176 }
177
178} // namespace utils