View Javadoc
1   package org.junit.runners.model;
2   
3   import static java.lang.reflect.Modifier.isStatic;
4   import static org.junit.internal.MethodSorter.NAME_ASCENDING;
5   
6   import java.lang.annotation.Annotation;
7   import java.lang.reflect.Constructor;
8   import java.lang.reflect.Field;
9   import java.lang.reflect.Method;
10  import java.lang.reflect.Modifier;
11  import java.util.ArrayList;
12  import java.util.Arrays;
13  import java.util.Collections;
14  import java.util.Comparator;
15  import java.util.LinkedHashMap;
16  import java.util.LinkedHashSet;
17  import java.util.List;
18  import java.util.Map;
19  import java.util.Set;
20  
21  import org.junit.Assert;
22  import org.junit.Before;
23  import org.junit.BeforeClass;
24  import org.junit.internal.MethodSorter;
25  
26  /**
27   * Wraps a class to be run, providing method validation and annotation searching
28   *
29   * @since 4.5
30   */
31  public class TestClass implements Annotatable {
32      private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
33      private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
34  
35      private final Class<?> clazz;
36      private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
37      private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
38  
39      /**
40       * Creates a {@code TestClass} wrapping {@code clazz}. Each time this
41       * constructor executes, the class is scanned for annotations, which can be
42       * an expensive process (we hope in future JDK's it will not be.) Therefore,
43       * try to share instances of {@code TestClass} where possible.
44       */
45      public TestClass(Class<?> clazz) {
46          this.clazz = clazz;
47          if (clazz != null && clazz.getConstructors().length > 1) {
48              throw new IllegalArgumentException(
49                      "Test class can only have one constructor");
50          }
51  
52          Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
53                  new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
54          Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
55                  new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
56  
57          scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
58  
59          this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
60          this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
61      }
62  
63      protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
64          for (Class<?> eachClass : getSuperClasses(clazz)) {
65              for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
66                  addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
67              }
68              // ensuring fields are sorted to make sure that entries are inserted
69              // and read from fieldForAnnotations in a deterministic order
70              for (Field eachField : getSortedDeclaredFields(eachClass)) {
71                  addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
72              }
73          }
74      }
75  
76      private static Field[] getSortedDeclaredFields(Class<?> clazz) {
77          Field[] declaredFields = clazz.getDeclaredFields();
78          Arrays.sort(declaredFields, FIELD_COMPARATOR);
79          return declaredFields;
80      }
81  
82      protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
83              Map<Class<? extends Annotation>, List<T>> map) {
84          for (Annotation each : member.getAnnotations()) {
85              Class<? extends Annotation> type = each.annotationType();
86              List<T> members = getAnnotatedMembers(map, type, true);
87              if (member.isShadowedBy(members)) {
88                  return;
89              }
90              if (runsTopToBottom(type)) {
91                  members.add(0, member);
92              } else {
93                  members.add(member);
94              }
95          }
96      }
97  
98      private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
99              makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
100         LinkedHashMap<Class<? extends Annotation>, List<T>> copy =
101                 new LinkedHashMap<Class<? extends Annotation>, List<T>>();
102         for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
103             copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
104         }
105         return Collections.unmodifiableMap(copy);
106     }
107 
108     /**
109      * Returns, efficiently, all the non-overridden methods in this class and
110      * its superclasses that are annotated}.
111      * 
112      * @since 4.12
113      */
114     public List<FrameworkMethod> getAnnotatedMethods() {
115         List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
116         Collections.sort(methods, METHOD_COMPARATOR);
117         return methods;
118     }
119 
120     /**
121      * Returns, efficiently, all the non-overridden methods in this class and
122      * its superclasses that are annotated with {@code annotationClass}.
123      */
124     public List<FrameworkMethod> getAnnotatedMethods(
125             Class<? extends Annotation> annotationClass) {
126         return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
127     }
128 
129     /**
130      * Returns, efficiently, all the non-overridden fields in this class and its
131      * superclasses that are annotated.
132      * 
133      * @since 4.12
134      */
135     public List<FrameworkField> getAnnotatedFields() {
136         return collectValues(fieldsForAnnotations);
137     }
138 
139     /**
140      * Returns, efficiently, all the non-overridden fields in this class and its
141      * superclasses that are annotated with {@code annotationClass}.
142      */
143     public List<FrameworkField> getAnnotatedFields(
144             Class<? extends Annotation> annotationClass) {
145         return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
146     }
147 
148     private <T> List<T> collectValues(Map<?, List<T>> map) {
149         Set<T> values = new LinkedHashSet<T>();
150         for (List<T> additionalValues : map.values()) {
151             values.addAll(additionalValues);
152         }
153         return new ArrayList<T>(values);
154     }
155 
156     private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
157             Class<? extends Annotation> type, boolean fillIfAbsent) {
158         if (!map.containsKey(type) && fillIfAbsent) {
159             map.put(type, new ArrayList<T>());
160         }
161         List<T> members = map.get(type);
162         return members == null ? Collections.<T>emptyList() : members;
163     }
164 
165     private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
166         return annotation.equals(Before.class)
167                 || annotation.equals(BeforeClass.class);
168     }
169 
170     private static List<Class<?>> getSuperClasses(Class<?> testClass) {
171         ArrayList<Class<?>> results = new ArrayList<Class<?>>();
172         Class<?> current = testClass;
173         while (current != null) {
174             results.add(current);
175             current = current.getSuperclass();
176         }
177         return results;
178     }
179 
180     /**
181      * Returns the underlying Java class.
182      */
183     public Class<?> getJavaClass() {
184         return clazz;
185     }
186 
187     /**
188      * Returns the class's name.
189      */
190     public String getName() {
191         if (clazz == null) {
192             return "null";
193         }
194         return clazz.getName();
195     }
196 
197     /**
198      * Returns the only public constructor in the class, or throws an {@code
199      * AssertionError} if there are more or less than one.
200      */
201 
202     public Constructor<?> getOnlyConstructor() {
203         Constructor<?>[] constructors = clazz.getConstructors();
204         Assert.assertEquals(1, constructors.length);
205         return constructors[0];
206     }
207 
208     /**
209      * Returns the annotations on this class
210      */
211     public Annotation[] getAnnotations() {
212         if (clazz == null) {
213             return new Annotation[0];
214         }
215         return clazz.getAnnotations();
216     }
217 
218     public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
219         if (clazz == null) {
220             return null;
221         }
222         return clazz.getAnnotation(annotationType);
223     }
224 
225     public <T> List<T> getAnnotatedFieldValues(Object test,
226             Class<? extends Annotation> annotationClass, Class<T> valueClass) {
227         List<T> results = new ArrayList<T>();
228         for (FrameworkField each : getAnnotatedFields(annotationClass)) {
229             try {
230                 Object fieldValue = each.get(test);
231                 if (valueClass.isInstance(fieldValue)) {
232                     results.add(valueClass.cast(fieldValue));
233                 }
234             } catch (IllegalAccessException e) {
235                 throw new RuntimeException(
236                         "How did getFields return a field we couldn't access?", e);
237             }
238         }
239         return results;
240     }
241 
242     public <T> List<T> getAnnotatedMethodValues(Object test,
243             Class<? extends Annotation> annotationClass, Class<T> valueClass) {
244         List<T> results = new ArrayList<T>();
245         for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
246             try {
247                 Object fieldValue = each.invokeExplosively(test);
248                 if (valueClass.isInstance(fieldValue)) {
249                     results.add(valueClass.cast(fieldValue));
250                 }
251             } catch (Throwable e) {
252                 throw new RuntimeException(
253                         "Exception in " + each.getName(), e);
254             }
255         }
256         return results;
257     }
258 
259     public boolean isPublic() {
260         return Modifier.isPublic(clazz.getModifiers());
261     }
262 
263     public boolean isANonStaticInnerClass() {
264         return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
265     }
266 
267     @Override
268     public int hashCode() {
269         return (clazz == null) ? 0 : clazz.hashCode();
270     }
271 
272     @Override
273     public boolean equals(Object obj) {
274         if (this == obj) {
275             return true;
276         }
277         if (obj == null) {
278             return false;
279         }
280         if (getClass() != obj.getClass()) {
281             return false;
282         }
283         TestClass other = (TestClass) obj;
284         return clazz == other.clazz;
285     }
286 
287     /**
288      * Compares two fields by its name.
289      */
290     private static class FieldComparator implements Comparator<Field> {
291         public int compare(Field left, Field right) {
292             return left.getName().compareTo(right.getName());
293         }
294     }
295 
296     /**
297      * Compares two methods by its name.
298      */
299     private static class MethodComparator implements
300             Comparator<FrameworkMethod> {
301         public int compare(FrameworkMethod left, FrameworkMethod right) {
302             return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
303         }
304     }
305 }