View Javadoc
1   package org.junit.internal.runners.statements;
2   
3   import java.lang.management.ManagementFactory;
4   import java.lang.management.ThreadMXBean;
5   import java.util.Arrays;
6   import java.util.concurrent.Callable;
7   import java.util.concurrent.CountDownLatch;
8   import java.util.concurrent.ExecutionException;
9   import java.util.concurrent.FutureTask;
10  import java.util.concurrent.TimeUnit;
11  import java.util.concurrent.TimeoutException;
12  
13  import org.junit.runners.model.MultipleFailureException;
14  import org.junit.runners.model.Statement;
15  import org.junit.runners.model.TestTimedOutException;
16  
17  public class FailOnTimeout extends Statement {
18      private final Statement originalStatement;
19      private final TimeUnit timeUnit;
20      private final long timeout;
21      private final boolean lookForStuckThread;
22      private volatile ThreadGroup threadGroup = null;
23  
24      /**
25       * Returns a new builder for building an instance.
26       *
27       * @since 4.12
28       */
29      public static Builder builder() {
30          return new Builder();
31      }
32  
33      /**
34       * Creates an instance wrapping the given statement with the given timeout in milliseconds.
35       *
36       * @param statement the statement to wrap
37       * @param timeoutMillis the timeout in milliseconds
38       * @deprecated use {@link #builder()} instead.
39       */
40      @Deprecated
41      public FailOnTimeout(Statement statement, long timeoutMillis) {
42          this(builder().withTimeout(timeoutMillis, TimeUnit.MILLISECONDS), statement);
43      }
44  
45      private FailOnTimeout(Builder builder, Statement statement) {
46          originalStatement = statement;
47          timeout = builder.timeout;
48          timeUnit = builder.unit;
49          lookForStuckThread = builder.lookForStuckThread;
50      }
51  
52      /**
53       * Builder for {@link FailOnTimeout}.
54       *
55       * @since 4.12
56       */
57      public static class Builder {
58          private boolean lookForStuckThread = false;
59          private long timeout = 0;
60          private TimeUnit unit = TimeUnit.SECONDS;
61  
62          private Builder() {
63          }
64  
65          /**
66           * Specifies the time to wait before timing out the test.
67           *
68           * <p>If this is not called, or is called with a {@code timeout} of
69           * {@code 0}, the returned {@code Statement} will wait forever for the
70           * test to complete, however the test will still launch from a separate
71           * thread. This can be useful for disabling timeouts in environments
72           * where they are dynamically set based on some property.
73           *
74           * @param timeout the maximum time to wait
75           * @param unit the time unit of the {@code timeout} argument
76           * @return {@code this} for method chaining.
77           */
78          public Builder withTimeout(long timeout, TimeUnit unit) {
79              if (timeout < 0) {
80                  throw new IllegalArgumentException("timeout must be non-negative");
81              }
82              if (unit == null) {
83                  throw new NullPointerException("TimeUnit cannot be null");
84              }
85              this.timeout = timeout;
86              this.unit = unit;
87              return this;
88          }
89  
90          /**
91           * Specifies whether to look for a stuck thread.  If a timeout occurs and this
92           * feature is enabled, the test will look for a thread that appears to be stuck
93           * and dump its backtrace.  This feature is experimental.  Behavior may change
94           * after the 4.12 release in response to feedback.
95           *
96           * @param enable {@code true} to enable the feature
97           * @return {@code this} for method chaining.
98           */
99          public Builder withLookingForStuckThread(boolean enable) {
100             this.lookForStuckThread = enable;
101             return this;
102         }
103 
104         /**
105          * Builds a {@link FailOnTimeout} instance using the values in this builder,
106          * wrapping the given statement.
107          *
108          * @param statement
109          */
110         public FailOnTimeout build(Statement statement) {
111             if (statement == null) {
112                 throw new NullPointerException("statement cannot be null");
113             }
114             return new FailOnTimeout(this, statement);
115         }
116     }
117 
118     @Override
119     public void evaluate() throws Throwable {
120         CallableStatement callable = new CallableStatement();
121         FutureTask<Throwable> task = new FutureTask<Throwable>(callable);
122         threadGroup = new ThreadGroup("FailOnTimeoutGroup");
123         Thread thread = new Thread(threadGroup, task, "Time-limited test");
124         thread.setDaemon(true);
125         thread.start();
126         callable.awaitStarted();
127         Throwable throwable = getResult(task, thread);
128         if (throwable != null) {
129             throw throwable;
130         }
131     }
132 
133     /**
134      * Wait for the test task, returning the exception thrown by the test if the
135      * test failed, an exception indicating a timeout if the test timed out, or
136      * {@code null} if the test passed.
137      */
138     private Throwable getResult(FutureTask<Throwable> task, Thread thread) {
139         try {
140             if (timeout > 0) {
141                 return task.get(timeout, timeUnit);
142             } else {
143                 return task.get();
144             }
145         } catch (InterruptedException e) {
146             return e; // caller will re-throw; no need to call Thread.interrupt()
147         } catch (ExecutionException e) {
148             // test failed; have caller re-throw the exception thrown by the test
149             return e.getCause();
150         } catch (TimeoutException e) {
151             return createTimeoutException(thread);
152         }
153     }
154 
155     private Exception createTimeoutException(Thread thread) {
156         StackTraceElement[] stackTrace = thread.getStackTrace();
157         final Thread stuckThread = lookForStuckThread ? getStuckThread(thread) : null;
158         Exception currThreadException = new TestTimedOutException(timeout, timeUnit);
159         if (stackTrace != null) {
160             currThreadException.setStackTrace(stackTrace);
161             thread.interrupt();
162         }
163         if (stuckThread != null) {
164             Exception stuckThreadException = 
165                 new Exception ("Appears to be stuck in thread " +
166                                stuckThread.getName());
167             stuckThreadException.setStackTrace(getStackTrace(stuckThread));
168             return new MultipleFailureException(
169                 Arrays.<Throwable>asList(currThreadException, stuckThreadException));
170         } else {
171             return currThreadException;
172         }
173     }
174 
175     /**
176      * Retrieves the stack trace for a given thread.
177      * @param thread The thread whose stack is to be retrieved.
178      * @return The stack trace; returns a zero-length array if the thread has 
179      * terminated or the stack cannot be retrieved for some other reason.
180      */
181     private StackTraceElement[] getStackTrace(Thread thread) {
182         try {
183             return thread.getStackTrace();
184         } catch (SecurityException e) {
185             return new StackTraceElement[0];
186         }
187     }
188 
189     /**
190      * Determines whether the test appears to be stuck in some thread other than
191      * the "main thread" (the one created to run the test).  This feature is experimental.
192      * Behavior may change after the 4.12 release in response to feedback.
193      * @param mainThread The main thread created by {@code evaluate()}
194      * @return The thread which appears to be causing the problem, if different from
195      * {@code mainThread}, or {@code null} if the main thread appears to be the
196      * problem or if the thread cannot be determined.  The return value is never equal 
197      * to {@code mainThread}.
198      */
199     private Thread getStuckThread(Thread mainThread) {
200         if (threadGroup == null) {
201             return null;
202         }
203         Thread[] threadsInGroup = getThreadArray(threadGroup);
204         if (threadsInGroup == null) {
205             return null;
206         }
207 
208         // Now that we have all the threads in the test's thread group: Assume that
209         // any thread we're "stuck" in is RUNNABLE.  Look for all RUNNABLE threads. 
210         // If just one, we return that (unless it equals threadMain).  If there's more
211         // than one, pick the one that's using the most CPU time, if this feature is
212         // supported.
213         Thread stuckThread = null;
214         long maxCpuTime = 0;
215         for (Thread thread : threadsInGroup) {
216             if (thread.getState() == Thread.State.RUNNABLE) {
217                 long threadCpuTime = cpuTime(thread);
218                 if (stuckThread == null || threadCpuTime > maxCpuTime) {
219                     stuckThread = thread;
220                     maxCpuTime = threadCpuTime;
221                 }
222             }               
223         }
224         return (stuckThread == mainThread) ? null : stuckThread;
225     }
226 
227     /**
228      * Returns all active threads belonging to a thread group.  
229      * @param group The thread group.
230      * @return The active threads in the thread group.  The result should be a
231      * complete list of the active threads at some point in time.  Returns {@code null}
232      * if this cannot be determined, e.g. because new threads are being created at an
233      * extremely fast rate.
234      */
235     private Thread[] getThreadArray(ThreadGroup group) {
236         final int count = group.activeCount(); // this is just an estimate
237         int enumSize = Math.max(count * 2, 100);
238         int enumCount;
239         Thread[] threads;
240         int loopCount = 0;
241         while (true) {
242             threads = new Thread[enumSize];
243             enumCount = group.enumerate(threads);
244             if (enumCount < enumSize) {
245                 break;
246             }
247             // if there are too many threads to fit into the array, enumerate's result
248             // is >= the array's length; therefore we can't trust that it returned all
249             // the threads.  Try again.
250             enumSize += 100;
251             if (++loopCount >= 5) {
252                 return null;
253             }
254             // threads are proliferating too fast for us.  Bail before we get into 
255             // trouble.
256         }
257         return copyThreads(threads, enumCount);
258     }
259 
260     /**
261      * Returns an array of the first {@code count} Threads in {@code threads}. 
262      * (Use instead of Arrays.copyOf to maintain compatibility with Java 1.5.)
263      * @param threads The source array.
264      * @param count The maximum length of the result array.
265      * @return The first {@count} (at most) elements of {@code threads}.
266      */
267     private Thread[] copyThreads(Thread[] threads, int count) {
268         int length = Math.min(count, threads.length);
269         Thread[] result = new Thread[length];
270         for (int i = 0; i < length; i++) {
271             result[i] = threads[i];
272         }
273         return result;
274     }
275 
276     /**
277      * Returns the CPU time used by a thread, if possible.
278      * @param thr The thread to query.
279      * @return The CPU time used by {@code thr}, or 0 if it cannot be determined.
280      */
281     private long cpuTime (Thread thr) {
282         ThreadMXBean mxBean = ManagementFactory.getThreadMXBean();
283         if (mxBean.isThreadCpuTimeSupported()) {
284             try {
285                 return mxBean.getThreadCpuTime(thr.getId());
286             } catch (UnsupportedOperationException e) {
287             }
288         }
289         return 0;
290     }
291 
292     private class CallableStatement implements Callable<Throwable> {
293         private final CountDownLatch startLatch = new CountDownLatch(1);
294 
295         public Throwable call() throws Exception {
296             try {
297                 startLatch.countDown();
298                 originalStatement.evaluate();
299             } catch (Exception e) {
300                 throw e;
301             } catch (Throwable e) {
302                 return e;
303             }
304             return null;
305         }
306 
307         public void awaitStarted() throws InterruptedException {
308             startLatch.await();
309         }
310     }
311 }