1 package org.junit.runners.model;
2
3 import static java.lang.reflect.Modifier.isStatic;
4 import static org.junit.internal.MethodSorter.NAME_ASCENDING;
5
6 import java.lang.annotation.Annotation;
7 import java.lang.reflect.Constructor;
8 import java.lang.reflect.Field;
9 import java.lang.reflect.Method;
10 import java.lang.reflect.Modifier;
11 import java.util.ArrayList;
12 import java.util.Arrays;
13 import java.util.Collections;
14 import java.util.Comparator;
15 import java.util.LinkedHashMap;
16 import java.util.LinkedHashSet;
17 import java.util.List;
18 import java.util.Map;
19 import java.util.Set;
20
21 import org.junit.Assert;
22 import org.junit.Before;
23 import org.junit.BeforeClass;
24 import org.junit.internal.MethodSorter;
25
26
27
28
29
30
31 public class TestClass implements Annotatable {
32 private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
33 private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
34
35 private final Class<?> clazz;
36 private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
37 private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
38
39
40
41
42
43
44
45 public TestClass(Class<?> clazz) {
46 this.clazz = clazz;
47 if (clazz != null && clazz.getConstructors().length > 1) {
48 throw new IllegalArgumentException(
49 "Test class can only have one constructor");
50 }
51
52 Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
53 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
54 Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
55 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
56
57 scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
58
59 this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
60 this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
61 }
62
63 protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
64 for (Class<?> eachClass : getSuperClasses(clazz)) {
65 for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
66 addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
67 }
68
69
70 for (Field eachField : getSortedDeclaredFields(eachClass)) {
71 addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
72 }
73 }
74 }
75
76 private static Field[] getSortedDeclaredFields(Class<?> clazz) {
77 Field[] declaredFields = clazz.getDeclaredFields();
78 Arrays.sort(declaredFields, FIELD_COMPARATOR);
79 return declaredFields;
80 }
81
82 protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
83 Map<Class<? extends Annotation>, List<T>> map) {
84 for (Annotation each : member.getAnnotations()) {
85 Class<? extends Annotation> type = each.annotationType();
86 List<T> members = getAnnotatedMembers(map, type, true);
87 if (member.isShadowedBy(members)) {
88 return;
89 }
90 if (runsTopToBottom(type)) {
91 members.add(0, member);
92 } else {
93 members.add(member);
94 }
95 }
96 }
97
98 private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
99 makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
100 LinkedHashMap<Class<? extends Annotation>, List<T>> copy =
101 new LinkedHashMap<Class<? extends Annotation>, List<T>>();
102 for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
103 copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
104 }
105 return Collections.unmodifiableMap(copy);
106 }
107
108
109
110
111
112
113
114 public List<FrameworkMethod> getAnnotatedMethods() {
115 List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
116 Collections.sort(methods, METHOD_COMPARATOR);
117 return methods;
118 }
119
120
121
122
123
124 public List<FrameworkMethod> getAnnotatedMethods(
125 Class<? extends Annotation> annotationClass) {
126 return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
127 }
128
129
130
131
132
133
134
135 public List<FrameworkField> getAnnotatedFields() {
136 return collectValues(fieldsForAnnotations);
137 }
138
139
140
141
142
143 public List<FrameworkField> getAnnotatedFields(
144 Class<? extends Annotation> annotationClass) {
145 return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
146 }
147
148 private <T> List<T> collectValues(Map<?, List<T>> map) {
149 Set<T> values = new LinkedHashSet<T>();
150 for (List<T> additionalValues : map.values()) {
151 values.addAll(additionalValues);
152 }
153 return new ArrayList<T>(values);
154 }
155
156 private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
157 Class<? extends Annotation> type, boolean fillIfAbsent) {
158 if (!map.containsKey(type) && fillIfAbsent) {
159 map.put(type, new ArrayList<T>());
160 }
161 List<T> members = map.get(type);
162 return members == null ? Collections.<T>emptyList() : members;
163 }
164
165 private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
166 return annotation.equals(Before.class)
167 || annotation.equals(BeforeClass.class);
168 }
169
170 private static List<Class<?>> getSuperClasses(Class<?> testClass) {
171 ArrayList<Class<?>> results = new ArrayList<Class<?>>();
172 Class<?> current = testClass;
173 while (current != null) {
174 results.add(current);
175 current = current.getSuperclass();
176 }
177 return results;
178 }
179
180
181
182
183 public Class<?> getJavaClass() {
184 return clazz;
185 }
186
187
188
189
190 public String getName() {
191 if (clazz == null) {
192 return "null";
193 }
194 return clazz.getName();
195 }
196
197
198
199
200
201
202 public Constructor<?> getOnlyConstructor() {
203 Constructor<?>[] constructors = clazz.getConstructors();
204 Assert.assertEquals(1, constructors.length);
205 return constructors[0];
206 }
207
208
209
210
211 public Annotation[] getAnnotations() {
212 if (clazz == null) {
213 return new Annotation[0];
214 }
215 return clazz.getAnnotations();
216 }
217
218 public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
219 if (clazz == null) {
220 return null;
221 }
222 return clazz.getAnnotation(annotationType);
223 }
224
225 public <T> List<T> getAnnotatedFieldValues(Object test,
226 Class<? extends Annotation> annotationClass, Class<T> valueClass) {
227 List<T> results = new ArrayList<T>();
228 for (FrameworkField each : getAnnotatedFields(annotationClass)) {
229 try {
230 Object fieldValue = each.get(test);
231 if (valueClass.isInstance(fieldValue)) {
232 results.add(valueClass.cast(fieldValue));
233 }
234 } catch (IllegalAccessException e) {
235 throw new RuntimeException(
236 "How did getFields return a field we couldn't access?", e);
237 }
238 }
239 return results;
240 }
241
242 public <T> List<T> getAnnotatedMethodValues(Object test,
243 Class<? extends Annotation> annotationClass, Class<T> valueClass) {
244 List<T> results = new ArrayList<T>();
245 for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
246 try {
247 Object fieldValue = each.invokeExplosively(test);
248 if (valueClass.isInstance(fieldValue)) {
249 results.add(valueClass.cast(fieldValue));
250 }
251 } catch (Throwable e) {
252 throw new RuntimeException(
253 "Exception in " + each.getName(), e);
254 }
255 }
256 return results;
257 }
258
259 public boolean isPublic() {
260 return Modifier.isPublic(clazz.getModifiers());
261 }
262
263 public boolean isANonStaticInnerClass() {
264 return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
265 }
266
267 @Override
268 public int hashCode() {
269 return (clazz == null) ? 0 : clazz.hashCode();
270 }
271
272 @Override
273 public boolean equals(Object obj) {
274 if (this == obj) {
275 return true;
276 }
277 if (obj == null) {
278 return false;
279 }
280 if (getClass() != obj.getClass()) {
281 return false;
282 }
283 TestClass other = (TestClass) obj;
284 return clazz == other.clazz;
285 }
286
287
288
289
290 private static class FieldComparator implements Comparator<Field> {
291 public int compare(Field left, Field right) {
292 return left.getName().compareTo(right.getName());
293 }
294 }
295
296
297
298
299 private static class MethodComparator implements
300 Comparator<FrameworkMethod> {
301 public int compare(FrameworkMethod left, FrameworkMethod right) {
302 return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
303 }
304 }
305 }