001    package org.junit.runners.model;
002    
003    import static java.lang.reflect.Modifier.isStatic;
004    import static org.junit.internal.MethodSorter.NAME_ASCENDING;
005    
006    import java.lang.annotation.Annotation;
007    import java.lang.reflect.Constructor;
008    import java.lang.reflect.Field;
009    import java.lang.reflect.Method;
010    import java.lang.reflect.Modifier;
011    import java.util.ArrayList;
012    import java.util.Arrays;
013    import java.util.Collections;
014    import java.util.Comparator;
015    import java.util.LinkedHashMap;
016    import java.util.LinkedHashSet;
017    import java.util.List;
018    import java.util.Map;
019    import java.util.Set;
020    
021    import org.junit.Assert;
022    import org.junit.Before;
023    import org.junit.BeforeClass;
024    import org.junit.internal.MethodSorter;
025    
026    /**
027     * Wraps a class to be run, providing method validation and annotation searching
028     *
029     * @since 4.5
030     */
031    public class TestClass implements Annotatable {
032        private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
033        private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
034    
035        private final Class<?> clazz;
036        private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
037        private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
038    
039        /**
040         * Creates a {@code TestClass} wrapping {@code clazz}. Each time this
041         * constructor executes, the class is scanned for annotations, which can be
042         * an expensive process (we hope in future JDK's it will not be.) Therefore,
043         * try to share instances of {@code TestClass} where possible.
044         */
045        public TestClass(Class<?> clazz) {
046            this.clazz = clazz;
047            if (clazz != null && clazz.getConstructors().length > 1) {
048                throw new IllegalArgumentException(
049                        "Test class can only have one constructor");
050            }
051    
052            Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
053                    new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
054            Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
055                    new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
056    
057            scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
058    
059            this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
060            this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
061        }
062    
063        protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
064            for (Class<?> eachClass : getSuperClasses(clazz)) {
065                for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
066                    addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
067                }
068                // ensuring fields are sorted to make sure that entries are inserted
069                // and read from fieldForAnnotations in a deterministic order
070                for (Field eachField : getSortedDeclaredFields(eachClass)) {
071                    addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
072                }
073            }
074        }
075    
076        private static Field[] getSortedDeclaredFields(Class<?> clazz) {
077            Field[] declaredFields = clazz.getDeclaredFields();
078            Arrays.sort(declaredFields, FIELD_COMPARATOR);
079            return declaredFields;
080        }
081    
082        protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
083                Map<Class<? extends Annotation>, List<T>> map) {
084            for (Annotation each : member.getAnnotations()) {
085                Class<? extends Annotation> type = each.annotationType();
086                List<T> members = getAnnotatedMembers(map, type, true);
087                T memberToAdd = member.handlePossibleBridgeMethod(members);
088                if (memberToAdd == null) {
089                    return;
090                }
091                if (runsTopToBottom(type)) {
092                    members.add(0, memberToAdd);
093                } else {
094                    members.add(memberToAdd);
095                }
096            }
097        }
098    
099        private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
100                makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
101            Map<Class<? extends Annotation>, List<T>> copy =
102                    new LinkedHashMap<Class<? extends Annotation>, List<T>>();
103            for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
104                copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
105            }
106            return Collections.unmodifiableMap(copy);
107        }
108    
109        /**
110         * Returns, efficiently, all the non-overridden methods in this class and
111         * its superclasses that are annotated}.
112         * 
113         * @since 4.12
114         */
115        public List<FrameworkMethod> getAnnotatedMethods() {
116            List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
117            Collections.sort(methods, METHOD_COMPARATOR);
118            return methods;
119        }
120    
121        /**
122         * Returns, efficiently, all the non-overridden methods in this class and
123         * its superclasses that are annotated with {@code annotationClass}.
124         */
125        public List<FrameworkMethod> getAnnotatedMethods(
126                Class<? extends Annotation> annotationClass) {
127            return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
128        }
129    
130        /**
131         * Returns, efficiently, all the non-overridden fields in this class and its
132         * superclasses that are annotated.
133         * 
134         * @since 4.12
135         */
136        public List<FrameworkField> getAnnotatedFields() {
137            return collectValues(fieldsForAnnotations);
138        }
139    
140        /**
141         * Returns, efficiently, all the non-overridden fields in this class and its
142         * superclasses that are annotated with {@code annotationClass}.
143         */
144        public List<FrameworkField> getAnnotatedFields(
145                Class<? extends Annotation> annotationClass) {
146            return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
147        }
148    
149        private <T> List<T> collectValues(Map<?, List<T>> map) {
150            Set<T> values = new LinkedHashSet<T>();
151            for (List<T> additionalValues : map.values()) {
152                values.addAll(additionalValues);
153            }
154            return new ArrayList<T>(values);
155        }
156    
157        private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
158                Class<? extends Annotation> type, boolean fillIfAbsent) {
159            if (!map.containsKey(type) && fillIfAbsent) {
160                map.put(type, new ArrayList<T>());
161            }
162            List<T> members = map.get(type);
163            return members == null ? Collections.<T>emptyList() : members;
164        }
165    
166        private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
167            return annotation.equals(Before.class)
168                    || annotation.equals(BeforeClass.class);
169        }
170    
171        private static List<Class<?>> getSuperClasses(Class<?> testClass) {
172            List<Class<?>> results = new ArrayList<Class<?>>();
173            Class<?> current = testClass;
174            while (current != null) {
175                results.add(current);
176                current = current.getSuperclass();
177            }
178            return results;
179        }
180    
181        /**
182         * Returns the underlying Java class.
183         */
184        public Class<?> getJavaClass() {
185            return clazz;
186        }
187    
188        /**
189         * Returns the class's name.
190         */
191        public String getName() {
192            if (clazz == null) {
193                return "null";
194            }
195            return clazz.getName();
196        }
197    
198        /**
199         * Returns the only public constructor in the class, or throws an {@code
200         * AssertionError} if there are more or less than one.
201         */
202    
203        public Constructor<?> getOnlyConstructor() {
204            Constructor<?>[] constructors = clazz.getConstructors();
205            Assert.assertEquals(1, constructors.length);
206            return constructors[0];
207        }
208    
209        /**
210         * Returns the annotations on this class
211         */
212        public Annotation[] getAnnotations() {
213            if (clazz == null) {
214                return new Annotation[0];
215            }
216            return clazz.getAnnotations();
217        }
218    
219        public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
220            if (clazz == null) {
221                return null;
222            }
223            return clazz.getAnnotation(annotationType);
224        }
225    
226        public <T> List<T> getAnnotatedFieldValues(Object test,
227                Class<? extends Annotation> annotationClass, Class<T> valueClass) {
228            final List<T> results = new ArrayList<T>();
229            collectAnnotatedFieldValues(test, annotationClass, valueClass,
230                    new MemberValueConsumer<T>() {
231                        public void accept(FrameworkMember<?> member, T value) {
232                            results.add(value);
233                        }
234                    });
235            return results;
236        }
237    
238        /**
239         * Finds the fields annotated with the specified annotation and having the specified type,
240         * retrieves the values and passes those to the specified consumer.
241         *
242         * @since 4.13
243         */
244        public <T> void collectAnnotatedFieldValues(Object test,
245                Class<? extends Annotation> annotationClass, Class<T> valueClass,
246                MemberValueConsumer<T> consumer) {
247            for (FrameworkField each : getAnnotatedFields(annotationClass)) {
248                try {
249                    Object fieldValue = each.get(test);
250                    if (valueClass.isInstance(fieldValue)) {
251                        consumer.accept(each, valueClass.cast(fieldValue));
252                    }
253                } catch (IllegalAccessException e) {
254                    throw new RuntimeException(
255                            "How did getFields return a field we couldn't access?", e);
256                }
257            }
258        }
259    
260        public <T> List<T> getAnnotatedMethodValues(Object test,
261                Class<? extends Annotation> annotationClass, Class<T> valueClass) {
262            final List<T> results = new ArrayList<T>();
263            collectAnnotatedMethodValues(test, annotationClass, valueClass,
264                    new MemberValueConsumer<T>() {
265                        public void accept(FrameworkMember<?> member, T value) {
266                            results.add(value);
267                        }
268                    });
269            return results;
270        }
271    
272        /**
273         * Finds the methods annotated with the specified annotation and returning the specified type,
274         * invokes it and pass the return value to the specified consumer.
275         *
276         * @since 4.13
277         */
278        public <T> void collectAnnotatedMethodValues(Object test,
279                Class<? extends Annotation> annotationClass, Class<T> valueClass,
280                MemberValueConsumer<T> consumer) {
281            for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
282                try {
283                    /*
284                     * A method annotated with @Rule may return a @TestRule or a @MethodRule,
285                     * we cannot call the method to check whether the return type matches our
286                     * expectation i.e. subclass of valueClass. If we do that then the method 
287                     * will be invoked twice and we do not want to do that. So we first check
288                     * whether return type matches our expectation and only then call the method
289                     * to fetch the MethodRule
290                     */
291                    if (valueClass.isAssignableFrom(each.getReturnType())) {
292                        Object fieldValue = each.invokeExplosively(test);
293                        consumer.accept(each, valueClass.cast(fieldValue));
294                    }
295                } catch (Throwable e) {
296                    throw new RuntimeException(
297                            "Exception in " + each.getName(), e);
298                }
299            }
300        }
301    
302        public boolean isPublic() {
303            return Modifier.isPublic(clazz.getModifiers());
304        }
305    
306        public boolean isANonStaticInnerClass() {
307            return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
308        }
309    
310        @Override
311        public int hashCode() {
312            return (clazz == null) ? 0 : clazz.hashCode();
313        }
314    
315        @Override
316        public boolean equals(Object obj) {
317            if (this == obj) {
318                return true;
319            }
320            if (obj == null) {
321                return false;
322            }
323            if (getClass() != obj.getClass()) {
324                return false;
325            }
326            TestClass other = (TestClass) obj;
327            return clazz == other.clazz;
328        }
329    
330        /**
331         * Compares two fields by its name.
332         */
333        private static class FieldComparator implements Comparator<Field> {
334            public int compare(Field left, Field right) {
335                return left.getName().compareTo(right.getName());
336            }
337        }
338    
339        /**
340         * Compares two methods by its name.
341         */
342        private static class MethodComparator implements
343                Comparator<FrameworkMethod> {
344            public int compare(FrameworkMethod left, FrameworkMethod right) {
345                return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
346            }
347        }
348    }