001 package org.junit.runners.model;
002
003 import static java.lang.reflect.Modifier.isStatic;
004 import static org.junit.internal.MethodSorter.NAME_ASCENDING;
005
006 import java.lang.annotation.Annotation;
007 import java.lang.reflect.Constructor;
008 import java.lang.reflect.Field;
009 import java.lang.reflect.Method;
010 import java.lang.reflect.Modifier;
011 import java.util.ArrayList;
012 import java.util.Arrays;
013 import java.util.Collections;
014 import java.util.Comparator;
015 import java.util.LinkedHashMap;
016 import java.util.LinkedHashSet;
017 import java.util.List;
018 import java.util.Map;
019 import java.util.Set;
020
021 import org.junit.Assert;
022 import org.junit.Before;
023 import org.junit.BeforeClass;
024 import org.junit.internal.MethodSorter;
025
026 /**
027 * Wraps a class to be run, providing method validation and annotation searching
028 *
029 * @since 4.5
030 */
031 public class TestClass implements Annotatable {
032 private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
033 private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
034
035 private final Class<?> clazz;
036 private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
037 private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
038
039 /**
040 * Creates a {@code TestClass} wrapping {@code clazz}. Each time this
041 * constructor executes, the class is scanned for annotations, which can be
042 * an expensive process (we hope in future JDK's it will not be.) Therefore,
043 * try to share instances of {@code TestClass} where possible.
044 */
045 public TestClass(Class<?> clazz) {
046 this.clazz = clazz;
047 if (clazz != null && clazz.getConstructors().length > 1) {
048 throw new IllegalArgumentException(
049 "Test class can only have one constructor");
050 }
051
052 Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
053 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
054 Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
055 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
056
057 scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
058
059 this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
060 this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
061 }
062
063 protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
064 for (Class<?> eachClass : getSuperClasses(clazz)) {
065 for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
066 addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
067 }
068 // ensuring fields are sorted to make sure that entries are inserted
069 // and read from fieldForAnnotations in a deterministic order
070 for (Field eachField : getSortedDeclaredFields(eachClass)) {
071 addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
072 }
073 }
074 }
075
076 private static Field[] getSortedDeclaredFields(Class<?> clazz) {
077 Field[] declaredFields = clazz.getDeclaredFields();
078 Arrays.sort(declaredFields, FIELD_COMPARATOR);
079 return declaredFields;
080 }
081
082 protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
083 Map<Class<? extends Annotation>, List<T>> map) {
084 for (Annotation each : member.getAnnotations()) {
085 Class<? extends Annotation> type = each.annotationType();
086 List<T> members = getAnnotatedMembers(map, type, true);
087 if (member.isShadowedBy(members)) {
088 return;
089 }
090 if (runsTopToBottom(type)) {
091 members.add(0, member);
092 } else {
093 members.add(member);
094 }
095 }
096 }
097
098 private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
099 makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
100 LinkedHashMap<Class<? extends Annotation>, List<T>> copy =
101 new LinkedHashMap<Class<? extends Annotation>, List<T>>();
102 for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
103 copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
104 }
105 return Collections.unmodifiableMap(copy);
106 }
107
108 /**
109 * Returns, efficiently, all the non-overridden methods in this class and
110 * its superclasses that are annotated}.
111 *
112 * @since 4.12
113 */
114 public List<FrameworkMethod> getAnnotatedMethods() {
115 List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
116 Collections.sort(methods, METHOD_COMPARATOR);
117 return methods;
118 }
119
120 /**
121 * Returns, efficiently, all the non-overridden methods in this class and
122 * its superclasses that are annotated with {@code annotationClass}.
123 */
124 public List<FrameworkMethod> getAnnotatedMethods(
125 Class<? extends Annotation> annotationClass) {
126 return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
127 }
128
129 /**
130 * Returns, efficiently, all the non-overridden fields in this class and its
131 * superclasses that are annotated.
132 *
133 * @since 4.12
134 */
135 public List<FrameworkField> getAnnotatedFields() {
136 return collectValues(fieldsForAnnotations);
137 }
138
139 /**
140 * Returns, efficiently, all the non-overridden fields in this class and its
141 * superclasses that are annotated with {@code annotationClass}.
142 */
143 public List<FrameworkField> getAnnotatedFields(
144 Class<? extends Annotation> annotationClass) {
145 return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
146 }
147
148 private <T> List<T> collectValues(Map<?, List<T>> map) {
149 Set<T> values = new LinkedHashSet<T>();
150 for (List<T> additionalValues : map.values()) {
151 values.addAll(additionalValues);
152 }
153 return new ArrayList<T>(values);
154 }
155
156 private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
157 Class<? extends Annotation> type, boolean fillIfAbsent) {
158 if (!map.containsKey(type) && fillIfAbsent) {
159 map.put(type, new ArrayList<T>());
160 }
161 List<T> members = map.get(type);
162 return members == null ? Collections.<T>emptyList() : members;
163 }
164
165 private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
166 return annotation.equals(Before.class)
167 || annotation.equals(BeforeClass.class);
168 }
169
170 private static List<Class<?>> getSuperClasses(Class<?> testClass) {
171 ArrayList<Class<?>> results = new ArrayList<Class<?>>();
172 Class<?> current = testClass;
173 while (current != null) {
174 results.add(current);
175 current = current.getSuperclass();
176 }
177 return results;
178 }
179
180 /**
181 * Returns the underlying Java class.
182 */
183 public Class<?> getJavaClass() {
184 return clazz;
185 }
186
187 /**
188 * Returns the class's name.
189 */
190 public String getName() {
191 if (clazz == null) {
192 return "null";
193 }
194 return clazz.getName();
195 }
196
197 /**
198 * Returns the only public constructor in the class, or throws an {@code
199 * AssertionError} if there are more or less than one.
200 */
201
202 public Constructor<?> getOnlyConstructor() {
203 Constructor<?>[] constructors = clazz.getConstructors();
204 Assert.assertEquals(1, constructors.length);
205 return constructors[0];
206 }
207
208 /**
209 * Returns the annotations on this class
210 */
211 public Annotation[] getAnnotations() {
212 if (clazz == null) {
213 return new Annotation[0];
214 }
215 return clazz.getAnnotations();
216 }
217
218 public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
219 if (clazz == null) {
220 return null;
221 }
222 return clazz.getAnnotation(annotationType);
223 }
224
225 public <T> List<T> getAnnotatedFieldValues(Object test,
226 Class<? extends Annotation> annotationClass, Class<T> valueClass) {
227 List<T> results = new ArrayList<T>();
228 for (FrameworkField each : getAnnotatedFields(annotationClass)) {
229 try {
230 Object fieldValue = each.get(test);
231 if (valueClass.isInstance(fieldValue)) {
232 results.add(valueClass.cast(fieldValue));
233 }
234 } catch (IllegalAccessException e) {
235 throw new RuntimeException(
236 "How did getFields return a field we couldn't access?", e);
237 }
238 }
239 return results;
240 }
241
242 public <T> List<T> getAnnotatedMethodValues(Object test,
243 Class<? extends Annotation> annotationClass, Class<T> valueClass) {
244 List<T> results = new ArrayList<T>();
245 for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
246 try {
247 /*
248 * A method annotated with @Rule may return a @TestRule or a @MethodRule,
249 * we cannot call the method to check whether the return type matches our
250 * expectation i.e. subclass of valueClass. If we do that then the method
251 * will be invoked twice and we do not want to do that. So we first check
252 * whether return type matches our expectation and only then call the method
253 * to fetch the MethodRule
254 */
255 if (valueClass.isAssignableFrom(each.getReturnType())) {
256 Object fieldValue = each.invokeExplosively(test);
257 results.add(valueClass.cast(fieldValue));
258 }
259 } catch (Throwable e) {
260 throw new RuntimeException(
261 "Exception in " + each.getName(), e);
262 }
263 }
264 return results;
265 }
266
267 public boolean isPublic() {
268 return Modifier.isPublic(clazz.getModifiers());
269 }
270
271 public boolean isANonStaticInnerClass() {
272 return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
273 }
274
275 @Override
276 public int hashCode() {
277 return (clazz == null) ? 0 : clazz.hashCode();
278 }
279
280 @Override
281 public boolean equals(Object obj) {
282 if (this == obj) {
283 return true;
284 }
285 if (obj == null) {
286 return false;
287 }
288 if (getClass() != obj.getClass()) {
289 return false;
290 }
291 TestClass other = (TestClass) obj;
292 return clazz == other.clazz;
293 }
294
295 /**
296 * Compares two fields by its name.
297 */
298 private static class FieldComparator implements Comparator<Field> {
299 public int compare(Field left, Field right) {
300 return left.getName().compareTo(right.getName());
301 }
302 }
303
304 /**
305 * Compares two methods by its name.
306 */
307 private static class MethodComparator implements
308 Comparator<FrameworkMethod> {
309 public int compare(FrameworkMethod left, FrameworkMethod right) {
310 return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
311 }
312 }
313 }