1 package org.junit.internal.runners.statements;
2
3 import java.lang.management.ManagementFactory;
4 import java.lang.management.ThreadMXBean;
5 import java.util.Arrays;
6 import java.util.concurrent.Callable;
7 import java.util.concurrent.CountDownLatch;
8 import java.util.concurrent.ExecutionException;
9 import java.util.concurrent.FutureTask;
10 import java.util.concurrent.TimeUnit;
11 import java.util.concurrent.TimeoutException;
12
13 import org.junit.runners.model.MultipleFailureException;
14 import org.junit.runners.model.Statement;
15 import org.junit.runners.model.TestTimedOutException;
16
17 public class FailOnTimeout extends Statement {
18 private final Statement originalStatement;
19 private final TimeUnit timeUnit;
20 private final long timeout;
21 private final boolean lookForStuckThread;
22 private volatile ThreadGroup threadGroup = null;
23
24
25
26
27
28
29 public static Builder builder() {
30 return new Builder();
31 }
32
33
34
35
36
37
38
39
40 @Deprecated
41 public FailOnTimeout(Statement statement, long timeoutMillis) {
42 this(builder().withTimeout(timeoutMillis, TimeUnit.MILLISECONDS), statement);
43 }
44
45 private FailOnTimeout(Builder builder, Statement statement) {
46 originalStatement = statement;
47 timeout = builder.timeout;
48 timeUnit = builder.unit;
49 lookForStuckThread = builder.lookForStuckThread;
50 }
51
52
53
54
55
56
57 public static class Builder {
58 private boolean lookForStuckThread = false;
59 private long timeout = 0;
60 private TimeUnit unit = TimeUnit.SECONDS;
61
62 private Builder() {
63 }
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 public Builder withTimeout(long timeout, TimeUnit unit) {
79 if (timeout < 0) {
80 throw new IllegalArgumentException("timeout must be non-negative");
81 }
82 if (unit == null) {
83 throw new NullPointerException("TimeUnit cannot be null");
84 }
85 this.timeout = timeout;
86 this.unit = unit;
87 return this;
88 }
89
90
91
92
93
94
95
96
97
98
99 public Builder withLookingForStuckThread(boolean enable) {
100 this.lookForStuckThread = enable;
101 return this;
102 }
103
104
105
106
107
108
109
110 public FailOnTimeout build(Statement statement) {
111 if (statement == null) {
112 throw new NullPointerException("statement cannot be null");
113 }
114 return new FailOnTimeout(this, statement);
115 }
116 }
117
118 @Override
119 public void evaluate() throws Throwable {
120 CallableStatement callable = new CallableStatement();
121 FutureTask<Throwable> task = new FutureTask<Throwable>(callable);
122 threadGroup = new ThreadGroup("FailOnTimeoutGroup");
123 Thread thread = new Thread(threadGroup, task, "Time-limited test");
124 thread.setDaemon(true);
125 thread.start();
126 callable.awaitStarted();
127 Throwable throwable = getResult(task, thread);
128 if (throwable != null) {
129 throw throwable;
130 }
131 }
132
133
134
135
136
137
138 private Throwable getResult(FutureTask<Throwable> task, Thread thread) {
139 try {
140 if (timeout > 0) {
141 return task.get(timeout, timeUnit);
142 } else {
143 return task.get();
144 }
145 } catch (InterruptedException e) {
146 return e;
147 } catch (ExecutionException e) {
148
149 return e.getCause();
150 } catch (TimeoutException e) {
151 return createTimeoutException(thread);
152 }
153 }
154
155 private Exception createTimeoutException(Thread thread) {
156 StackTraceElement[] stackTrace = thread.getStackTrace();
157 final Thread stuckThread = lookForStuckThread ? getStuckThread(thread) : null;
158 Exception currThreadException = new TestTimedOutException(timeout, timeUnit);
159 if (stackTrace != null) {
160 currThreadException.setStackTrace(stackTrace);
161 thread.interrupt();
162 }
163 if (stuckThread != null) {
164 Exception stuckThreadException =
165 new Exception ("Appears to be stuck in thread " +
166 stuckThread.getName());
167 stuckThreadException.setStackTrace(getStackTrace(stuckThread));
168 return new MultipleFailureException(
169 Arrays.<Throwable>asList(currThreadException, stuckThreadException));
170 } else {
171 return currThreadException;
172 }
173 }
174
175
176
177
178
179
180
181 private StackTraceElement[] getStackTrace(Thread thread) {
182 try {
183 return thread.getStackTrace();
184 } catch (SecurityException e) {
185 return new StackTraceElement[0];
186 }
187 }
188
189
190
191
192
193
194
195
196
197
198
199 private Thread getStuckThread(Thread mainThread) {
200 if (threadGroup == null) {
201 return null;
202 }
203 Thread[] threadsInGroup = getThreadArray(threadGroup);
204 if (threadsInGroup == null) {
205 return null;
206 }
207
208
209
210
211
212
213 Thread stuckThread = null;
214 long maxCpuTime = 0;
215 for (Thread thread : threadsInGroup) {
216 if (thread.getState() == Thread.State.RUNNABLE) {
217 long threadCpuTime = cpuTime(thread);
218 if (stuckThread == null || threadCpuTime > maxCpuTime) {
219 stuckThread = thread;
220 maxCpuTime = threadCpuTime;
221 }
222 }
223 }
224 return (stuckThread == mainThread) ? null : stuckThread;
225 }
226
227
228
229
230
231
232
233
234
235 private Thread[] getThreadArray(ThreadGroup group) {
236 final int count = group.activeCount();
237 int enumSize = Math.max(count * 2, 100);
238 int enumCount;
239 Thread[] threads;
240 int loopCount = 0;
241 while (true) {
242 threads = new Thread[enumSize];
243 enumCount = group.enumerate(threads);
244 if (enumCount < enumSize) {
245 break;
246 }
247
248
249
250 enumSize += 100;
251 if (++loopCount >= 5) {
252 return null;
253 }
254
255
256 }
257 return copyThreads(threads, enumCount);
258 }
259
260
261
262
263
264
265
266
267 private Thread[] copyThreads(Thread[] threads, int count) {
268 int length = Math.min(count, threads.length);
269 Thread[] result = new Thread[length];
270 for (int i = 0; i < length; i++) {
271 result[i] = threads[i];
272 }
273 return result;
274 }
275
276
277
278
279
280
281 private long cpuTime (Thread thr) {
282 ThreadMXBean mxBean = ManagementFactory.getThreadMXBean();
283 if (mxBean.isThreadCpuTimeSupported()) {
284 try {
285 return mxBean.getThreadCpuTime(thr.getId());
286 } catch (UnsupportedOperationException e) {
287 }
288 }
289 return 0;
290 }
291
292 private class CallableStatement implements Callable<Throwable> {
293 private final CountDownLatch startLatch = new CountDownLatch(1);
294
295 public Throwable call() throws Exception {
296 try {
297 startLatch.countDown();
298 originalStatement.evaluate();
299 } catch (Exception e) {
300 throw e;
301 } catch (Throwable e) {
302 return e;
303 }
304 return null;
305 }
306
307 public void awaitStarted() throws InterruptedException {
308 startLatch.await();
309 }
310 }
311 }