001    package org.junit.runners.model;
002    
003    import static java.lang.reflect.Modifier.isStatic;
004    import static org.junit.internal.MethodSorter.NAME_ASCENDING;
005    
006    import java.lang.annotation.Annotation;
007    import java.lang.reflect.Constructor;
008    import java.lang.reflect.Field;
009    import java.lang.reflect.Method;
010    import java.lang.reflect.Modifier;
011    import java.util.ArrayList;
012    import java.util.Arrays;
013    import java.util.Collections;
014    import java.util.Comparator;
015    import java.util.LinkedHashMap;
016    import java.util.LinkedHashSet;
017    import java.util.List;
018    import java.util.Map;
019    import java.util.Set;
020    
021    import org.junit.Assert;
022    import org.junit.Before;
023    import org.junit.BeforeClass;
024    import org.junit.internal.MethodSorter;
025    
026    /**
027     * Wraps a class to be run, providing method validation and annotation searching
028     *
029     * @since 4.5
030     */
031    public class TestClass implements Annotatable {
032        private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
033        private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
034    
035        private final Class<?> clazz;
036        private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
037        private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
038    
039        /**
040         * Creates a {@code TestClass} wrapping {@code clazz}. Each time this
041         * constructor executes, the class is scanned for annotations, which can be
042         * an expensive process (we hope in future JDK's it will not be.) Therefore,
043         * try to share instances of {@code TestClass} where possible.
044         */
045        public TestClass(Class<?> clazz) {
046            this.clazz = clazz;
047            if (clazz != null && clazz.getConstructors().length > 1) {
048                throw new IllegalArgumentException(
049                        "Test class can only have one constructor");
050            }
051    
052            Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
053                    new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
054            Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
055                    new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
056    
057            scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
058    
059            this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
060            this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
061        }
062    
063        protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
064            for (Class<?> eachClass : getSuperClasses(clazz)) {
065                for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
066                    addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
067                }
068                // ensuring fields are sorted to make sure that entries are inserted
069                // and read from fieldForAnnotations in a deterministic order
070                for (Field eachField : getSortedDeclaredFields(eachClass)) {
071                    addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
072                }
073            }
074        }
075    
076        private static Field[] getSortedDeclaredFields(Class<?> clazz) {
077            Field[] declaredFields = clazz.getDeclaredFields();
078            Arrays.sort(declaredFields, FIELD_COMPARATOR);
079            return declaredFields;
080        }
081    
082        protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
083                Map<Class<? extends Annotation>, List<T>> map) {
084            for (Annotation each : member.getAnnotations()) {
085                Class<? extends Annotation> type = each.annotationType();
086                List<T> members = getAnnotatedMembers(map, type, true);
087                if (member.isShadowedBy(members)) {
088                    return;
089                }
090                if (runsTopToBottom(type)) {
091                    members.add(0, member);
092                } else {
093                    members.add(member);
094                }
095            }
096        }
097    
098        private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
099                makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
100            LinkedHashMap<Class<? extends Annotation>, List<T>> copy =
101                    new LinkedHashMap<Class<? extends Annotation>, List<T>>();
102            for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
103                copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
104            }
105            return Collections.unmodifiableMap(copy);
106        }
107    
108        /**
109         * Returns, efficiently, all the non-overridden methods in this class and
110         * its superclasses that are annotated}.
111         * 
112         * @since 4.12
113         */
114        public List<FrameworkMethod> getAnnotatedMethods() {
115            List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
116            Collections.sort(methods, METHOD_COMPARATOR);
117            return methods;
118        }
119    
120        /**
121         * Returns, efficiently, all the non-overridden methods in this class and
122         * its superclasses that are annotated with {@code annotationClass}.
123         */
124        public List<FrameworkMethod> getAnnotatedMethods(
125                Class<? extends Annotation> annotationClass) {
126            return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
127        }
128    
129        /**
130         * Returns, efficiently, all the non-overridden fields in this class and its
131         * superclasses that are annotated.
132         * 
133         * @since 4.12
134         */
135        public List<FrameworkField> getAnnotatedFields() {
136            return collectValues(fieldsForAnnotations);
137        }
138    
139        /**
140         * Returns, efficiently, all the non-overridden fields in this class and its
141         * superclasses that are annotated with {@code annotationClass}.
142         */
143        public List<FrameworkField> getAnnotatedFields(
144                Class<? extends Annotation> annotationClass) {
145            return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
146        }
147    
148        private <T> List<T> collectValues(Map<?, List<T>> map) {
149            Set<T> values = new LinkedHashSet<T>();
150            for (List<T> additionalValues : map.values()) {
151                values.addAll(additionalValues);
152            }
153            return new ArrayList<T>(values);
154        }
155    
156        private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
157                Class<? extends Annotation> type, boolean fillIfAbsent) {
158            if (!map.containsKey(type) && fillIfAbsent) {
159                map.put(type, new ArrayList<T>());
160            }
161            List<T> members = map.get(type);
162            return members == null ? Collections.<T>emptyList() : members;
163        }
164    
165        private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
166            return annotation.equals(Before.class)
167                    || annotation.equals(BeforeClass.class);
168        }
169    
170        private static List<Class<?>> getSuperClasses(Class<?> testClass) {
171            ArrayList<Class<?>> results = new ArrayList<Class<?>>();
172            Class<?> current = testClass;
173            while (current != null) {
174                results.add(current);
175                current = current.getSuperclass();
176            }
177            return results;
178        }
179    
180        /**
181         * Returns the underlying Java class.
182         */
183        public Class<?> getJavaClass() {
184            return clazz;
185        }
186    
187        /**
188         * Returns the class's name.
189         */
190        public String getName() {
191            if (clazz == null) {
192                return "null";
193            }
194            return clazz.getName();
195        }
196    
197        /**
198         * Returns the only public constructor in the class, or throws an {@code
199         * AssertionError} if there are more or less than one.
200         */
201    
202        public Constructor<?> getOnlyConstructor() {
203            Constructor<?>[] constructors = clazz.getConstructors();
204            Assert.assertEquals(1, constructors.length);
205            return constructors[0];
206        }
207    
208        /**
209         * Returns the annotations on this class
210         */
211        public Annotation[] getAnnotations() {
212            if (clazz == null) {
213                return new Annotation[0];
214            }
215            return clazz.getAnnotations();
216        }
217    
218        public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
219            if (clazz == null) {
220                return null;
221            }
222            return clazz.getAnnotation(annotationType);
223        }
224    
225        public <T> List<T> getAnnotatedFieldValues(Object test,
226                Class<? extends Annotation> annotationClass, Class<T> valueClass) {
227            List<T> results = new ArrayList<T>();
228            for (FrameworkField each : getAnnotatedFields(annotationClass)) {
229                try {
230                    Object fieldValue = each.get(test);
231                    if (valueClass.isInstance(fieldValue)) {
232                        results.add(valueClass.cast(fieldValue));
233                    }
234                } catch (IllegalAccessException e) {
235                    throw new RuntimeException(
236                            "How did getFields return a field we couldn't access?", e);
237                }
238            }
239            return results;
240        }
241    
242        public <T> List<T> getAnnotatedMethodValues(Object test,
243                Class<? extends Annotation> annotationClass, Class<T> valueClass) {
244            List<T> results = new ArrayList<T>();
245            for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
246                try {
247                    /*
248                     * A method annotated with @Rule may return a @TestRule or a @MethodRule,
249                     * we cannot call the method to check whether the return type matches our
250                     * expectation i.e. subclass of valueClass. If we do that then the method 
251                     * will be invoked twice and we do not want to do that. So we first check
252                     * whether return type matches our expectation and only then call the method
253                     * to fetch the MethodRule
254                     */
255                    if (valueClass.isAssignableFrom(each.getReturnType())) {
256                        Object fieldValue = each.invokeExplosively(test);
257                        results.add(valueClass.cast(fieldValue));
258                    }
259                } catch (Throwable e) {
260                    throw new RuntimeException(
261                            "Exception in " + each.getName(), e);
262                }
263            }
264            return results;
265        }
266    
267        public boolean isPublic() {
268            return Modifier.isPublic(clazz.getModifiers());
269        }
270    
271        public boolean isANonStaticInnerClass() {
272            return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
273        }
274    
275        @Override
276        public int hashCode() {
277            return (clazz == null) ? 0 : clazz.hashCode();
278        }
279    
280        @Override
281        public boolean equals(Object obj) {
282            if (this == obj) {
283                return true;
284            }
285            if (obj == null) {
286                return false;
287            }
288            if (getClass() != obj.getClass()) {
289                return false;
290            }
291            TestClass other = (TestClass) obj;
292            return clazz == other.clazz;
293        }
294    
295        /**
296         * Compares two fields by its name.
297         */
298        private static class FieldComparator implements Comparator<Field> {
299            public int compare(Field left, Field right) {
300                return left.getName().compareTo(right.getName());
301            }
302        }
303    
304        /**
305         * Compares two methods by its name.
306         */
307        private static class MethodComparator implements
308                Comparator<FrameworkMethod> {
309            public int compare(FrameworkMethod left, FrameworkMethod right) {
310                return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
311            }
312        }
313    }