001    package org.junit.experimental;
002    
003    import java.util.concurrent.ExecutorService;
004    import java.util.concurrent.Executors;
005    import java.util.concurrent.TimeUnit;
006    
007    import org.junit.runner.Computer;
008    import org.junit.runner.Runner;
009    import org.junit.runners.ParentRunner;
010    import org.junit.runners.model.InitializationError;
011    import org.junit.runners.model.RunnerBuilder;
012    import org.junit.runners.model.RunnerScheduler;
013    
014    public class ParallelComputer extends Computer {
015        private final boolean classes;
016    
017        private final boolean methods;
018    
019        public ParallelComputer(boolean classes, boolean methods) {
020            this.classes = classes;
021            this.methods = methods;
022        }
023    
024        public static Computer classes() {
025            return new ParallelComputer(true, false);
026        }
027    
028        public static Computer methods() {
029            return new ParallelComputer(false, true);
030        }
031    
032        private static Runner parallelize(Runner runner) {
033            if (runner instanceof ParentRunner) {
034                ((ParentRunner<?>) runner).setScheduler(new RunnerScheduler() {
035                    private final ExecutorService fService = Executors.newCachedThreadPool();
036    
037                    public void schedule(Runnable childStatement) {
038                        fService.submit(childStatement);
039                    }
040    
041                    public void finished() {
042                        try {
043                            fService.shutdown();
044                            fService.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
045                        } catch (InterruptedException e) {
046                            e.printStackTrace(System.err);
047                        }
048                    }
049                });
050            }
051            return runner;
052        }
053    
054        @Override
055        public Runner getSuite(RunnerBuilder builder, java.lang.Class<?>[] classes)
056                throws InitializationError {
057            Runner suite = super.getSuite(builder, classes);
058            return this.classes ? parallelize(suite) : suite;
059        }
060    
061        @Override
062        protected Runner getRunner(RunnerBuilder builder, Class<?> testClass)
063                throws Throwable {
064            Runner runner = super.getRunner(builder, testClass);
065            return methods ? parallelize(runner) : runner;
066        }
067    }