Skip to content

Commit ad2358b

Browse files
GroupBy GroupedObservables should not re-subscribe to parent sequence
#282 Refactored to maintain a single subscription that propagates events to the correct child GroupedObservables.
1 parent 870d0b7 commit ad2358b

File tree

1 file changed

+237
-63
lines changed

1 file changed

+237
-63
lines changed

rxjava-core/src/main/java/rx/operators/OperationGroupBy.java

Lines changed: 237 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,24 @@
1717

1818
import static org.junit.Assert.*;
1919

20-
import java.util.ArrayList;
2120
import java.util.Arrays;
22-
import java.util.HashMap;
23-
import java.util.List;
21+
import java.util.Collection;
2422
import java.util.Map;
2523
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.concurrent.ConcurrentLinkedQueue;
25+
import java.util.concurrent.CountDownLatch;
26+
import java.util.concurrent.TimeUnit;
27+
import java.util.concurrent.atomic.AtomicInteger;
28+
import java.util.concurrent.atomic.AtomicReference;
2629

2730
import org.junit.Test;
2831

2932
import rx.Observable;
3033
import rx.Observer;
3134
import rx.Subscription;
3235
import rx.observables.GroupedObservable;
36+
import rx.subscriptions.Subscriptions;
37+
import rx.util.functions.Action1;
3338
import rx.util.functions.Func1;
3439
import rx.util.functions.Functions;
3540

@@ -55,69 +60,137 @@ public static <K, T> Func1<Observer<GroupedObservable<K, T>>, Subscription> grou
5560
}
5661

5762
private static class GroupBy<K, V> implements Func1<Observer<GroupedObservable<K, V>>, Subscription> {
63+
5864
private final Observable<KeyValue<K, V>> source;
65+
private final ConcurrentHashMap<K, GroupedSubject<K, V>> groupedObservables = new ConcurrentHashMap<K, GroupedSubject<K, V>>();
5966

6067
private GroupBy(Observable<KeyValue<K, V>> source) {
6168
this.source = source;
6269
}
6370

6471
@Override
6572
public Subscription call(final Observer<GroupedObservable<K, V>> observer) {
66-
return source.subscribe(new GroupByObserver(observer));
73+
return source.subscribe(new Observer<KeyValue<K, V>>() {
74+
75+
@Override
76+
public void onCompleted() {
77+
// we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
78+
for (GroupedSubject<K, V> o : groupedObservables.values()) {
79+
o.onCompleted();
80+
}
81+
// now the parent
82+
observer.onCompleted();
83+
}
84+
85+
@Override
86+
public void onError(Exception e) {
87+
// we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
88+
for (GroupedSubject<K, V> o : groupedObservables.values()) {
89+
o.onError(e);
90+
}
91+
// now the parent
92+
observer.onError(e);
93+
}
94+
95+
@Override
96+
public void onNext(KeyValue<K, V> value) {
97+
GroupedSubject<K, V> gs = groupedObservables.get(value.key);
98+
if (gs == null) {
99+
/*
100+
* Technically the source should be single-threaded so we shouldn't need to do this but I am
101+
* programming defensively as most operators are so this can work with a concurrent sequence
102+
* if it ends up receiving one.
103+
*/
104+
GroupedSubject<K, V> newGs = GroupedSubject.<K, V> create(value.key);
105+
GroupedSubject<K, V> existing = groupedObservables.putIfAbsent(value.key, newGs);
106+
if (existing == null) {
107+
// we won so use the one we created
108+
gs = newGs;
109+
// since we won the creation we emit this new GroupedObservable
110+
observer.onNext(gs);
111+
} else {
112+
// another thread beat us so use the existing one
113+
gs = existing;
114+
}
115+
}
116+
gs.onNext(value.value);
117+
}
118+
});
67119
}
120+
}
68121

69-
private class GroupByObserver implements Observer<KeyValue<K, V>> {
70-
private final Observer<GroupedObservable<K, V>> underlying;
122+
private static class GroupedSubject<K, T> extends GroupedObservable<K, T> implements Observer<T> {
71123

72-
private final ConcurrentHashMap<K, Boolean> keys = new ConcurrentHashMap<K, Boolean>();
124+
static <K, T> GroupedSubject<K, T> create(K key) {
125+
@SuppressWarnings("unchecked")
126+
final AtomicReference<Observer<T>> subscribedObserver = new AtomicReference<Observer<T>>(EMPTY_OBSERVER);
73127

74-
private GroupByObserver(Observer<GroupedObservable<K, V>> underlying) {
75-
this.underlying = underlying;
76-
}
128+
return new GroupedSubject<K, T>(key, new Func1<Observer<T>, Subscription>() {
77129

78-
@Override
79-
public void onCompleted() {
80-
underlying.onCompleted();
81-
}
130+
@Override
131+
public Subscription call(Observer<T> observer) {
132+
// register Observer
133+
subscribedObserver.set(observer);
82134

83-
@Override
84-
public void onError(Exception e) {
85-
underlying.onError(e);
86-
}
135+
return new Subscription() {
87136

88-
@Override
89-
public void onNext(final KeyValue<K, V> args) {
90-
K key = args.key;
91-
boolean newGroup = keys.putIfAbsent(key, true) == null;
92-
if (newGroup) {
93-
underlying.onNext(buildObservableFor(source, key));
137+
@SuppressWarnings("unchecked")
138+
@Override
139+
public void unsubscribe() {
140+
// we remove the Observer so we stop emitting further events (they will be ignored if parent continues to send)
141+
subscribedObserver.set(EMPTY_OBSERVER);
142+
// I don't believe we need to worry about the parent here as it's a separate sequence that would
143+
// be unsubscribed to directly if that needs to happen.
144+
}
145+
};
94146
}
95-
}
147+
}, subscribedObserver);
96148
}
97-
}
98149

99-
private static <K, R> GroupedObservable<K, R> buildObservableFor(Observable<KeyValue<K, R>> source, final K key) {
100-
final Observable<R> observable = source.filter(new Func1<KeyValue<K, R>, Boolean>() {
101-
@Override
102-
public Boolean call(KeyValue<K, R> pair) {
103-
return key.equals(pair.key);
104-
}
105-
}).map(new Func1<KeyValue<K, R>, R>() {
106-
@Override
107-
public R call(KeyValue<K, R> pair) {
108-
return pair.value;
109-
}
110-
});
111-
return new GroupedObservable<K, R>(key, new Func1<Observer<R>, Subscription>() {
150+
private final AtomicReference<Observer<T>> subscribedObserver;
112151

113-
@Override
114-
public Subscription call(Observer<R> observer) {
115-
return observable.subscribe(observer);
116-
}
152+
public GroupedSubject(K key, Func1<Observer<T>, Subscription> onSubscribe, AtomicReference<Observer<T>> subscribedObserver) {
153+
super(key, onSubscribe);
154+
this.subscribedObserver = subscribedObserver;
155+
}
156+
157+
@Override
158+
public void onCompleted() {
159+
subscribedObserver.get().onCompleted();
160+
}
161+
162+
@Override
163+
public void onError(Exception e) {
164+
subscribedObserver.get().onError(e);
165+
}
166+
167+
@Override
168+
public void onNext(T v) {
169+
subscribedObserver.get().onNext(v);
170+
}
117171

118-
});
119172
}
120173

174+
@SuppressWarnings("rawtypes")
175+
private static Observer EMPTY_OBSERVER = new Observer() {
176+
177+
@Override
178+
public void onCompleted() {
179+
// do nothing
180+
}
181+
182+
@Override
183+
public void onError(Exception e) {
184+
// do nothing
185+
}
186+
187+
@Override
188+
public void onNext(Object args) {
189+
// do nothing
190+
}
191+
192+
};
193+
121194
private static class KeyValue<K, V> {
122195
private final K key;
123196
private final V value;
@@ -141,45 +214,146 @@ public void testGroupBy() {
141214
Observable<String> source = Observable.from("one", "two", "three", "four", "five", "six");
142215
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));
143216

144-
Map<Integer, List<String>> map = toMap(grouped);
217+
Map<Integer, Collection<String>> map = toMap(grouped);
145218

146219
assertEquals(3, map.size());
147-
assertEquals(Arrays.asList("one", "two", "six"), map.get(3));
148-
assertEquals(Arrays.asList("four", "five"), map.get(4));
149-
assertEquals(Arrays.asList("three"), map.get(5));
150-
220+
assertArrayEquals(Arrays.asList("one", "two", "six").toArray(), map.get(3).toArray());
221+
assertArrayEquals(Arrays.asList("four", "five").toArray(), map.get(4).toArray());
222+
assertArrayEquals(Arrays.asList("three").toArray(), map.get(5).toArray());
151223
}
152224

153225
@Test
154226
public void testEmpty() {
155227
Observable<String> source = Observable.from();
156228
Observable<GroupedObservable<Integer, String>> grouped = Observable.create(groupBy(source, length));
157229

158-
Map<Integer, List<String>> map = toMap(grouped);
230+
Map<Integer, Collection<String>> map = toMap(grouped);
159231

160232
assertTrue(map.isEmpty());
161233
}
162234

163-
private static <K, V> Map<K, List<V>> toMap(Observable<GroupedObservable<K, V>> observable) {
164-
Map<K, List<V>> result = new HashMap<K, List<V>>();
165-
for (GroupedObservable<K, V> g : observable.toBlockingObservable().toIterable()) {
166-
K key = g.getKey();
235+
private static <K, V> Map<K, Collection<V>> toMap(Observable<GroupedObservable<K, V>> observable) {
167236

168-
for (V value : g.toBlockingObservable().toIterable()) {
169-
List<V> values = result.get(key);
170-
if (values == null) {
171-
values = new ArrayList<V>();
172-
result.put(key, values);
173-
}
237+
final ConcurrentHashMap<K, Collection<V>> result = new ConcurrentHashMap<K, Collection<V>>();
174238

175-
values.add(value);
176-
}
239+
observable.forEach(new Action1<GroupedObservable<K, V>>() {
177240

178-
}
241+
@Override
242+
public void call(final GroupedObservable<K, V> o) {
243+
result.put(o.getKey(), new ConcurrentLinkedQueue<V>());
244+
o.subscribe(new Action1<V>() {
245+
246+
@Override
247+
public void call(V v) {
248+
result.get(o.getKey()).add(v);
249+
}
250+
251+
});
252+
}
253+
});
179254

180255
return result;
181256
}
182257

258+
/**
259+
* Assert that only a single subscription to a stream occurs and that all events are received.
260+
*
261+
* @throws Exception
262+
*/
263+
@Test
264+
public void testGroupedEventStream() throws Exception {
265+
266+
final AtomicInteger eventCounter = new AtomicInteger();
267+
final AtomicInteger subscribeCounter = new AtomicInteger();
268+
final AtomicInteger groupCounter = new AtomicInteger();
269+
final CountDownLatch latch = new CountDownLatch(1);
270+
final int count = 100;
271+
final int groupCount = 2;
272+
273+
Observable<Event> es = Observable.create(new Func1<Observer<Event>, Subscription>() {
274+
275+
@Override
276+
public Subscription call(final Observer<Event> observer) {
277+
System.out.println("*** Subscribing to EventStream ***");
278+
subscribeCounter.incrementAndGet();
279+
new Thread(new Runnable() {
280+
281+
@Override
282+
public void run() {
283+
for (int i = 0; i < count; i++) {
284+
Event e = new Event();
285+
e.source = i % groupCount;
286+
e.message = "Event-" + i;
287+
observer.onNext(e);
288+
}
289+
observer.onCompleted();
290+
}
291+
292+
}).start();
293+
return Subscriptions.empty();
294+
}
295+
296+
});
297+
298+
es.groupBy(new Func1<Event, Integer>() {
299+
300+
@Override
301+
public Integer call(Event e) {
302+
return e.source;
303+
}
304+
}).mapMany(new Func1<GroupedObservable<Integer, Event>, Observable<String>>() {
305+
306+
@Override
307+
public Observable<String> call(GroupedObservable<Integer, Event> eventGroupedObservable) {
308+
System.out.println("GroupedObservable Key: " + eventGroupedObservable.getKey());
309+
groupCounter.incrementAndGet();
310+
311+
return eventGroupedObservable.map(new Func1<Event, String>() {
312+
313+
@Override
314+
public String call(Event event) {
315+
return "Source: " + event.source + " Message: " + event.message;
316+
}
317+
});
318+
319+
};
320+
}).subscribe(new Observer<String>() {
321+
322+
@Override
323+
public void onCompleted() {
324+
latch.countDown();
325+
}
326+
327+
@Override
328+
public void onError(Exception e) {
329+
e.printStackTrace();
330+
latch.countDown();
331+
}
332+
333+
@Override
334+
public void onNext(String outputMessage) {
335+
System.out.println(outputMessage);
336+
eventCounter.incrementAndGet();
337+
}
338+
});
339+
340+
latch.await(5000, TimeUnit.MILLISECONDS);
341+
assertEquals(1, subscribeCounter.get());
342+
assertEquals(groupCount, groupCounter.get());
343+
assertEquals(count, eventCounter.get());
344+
345+
}
346+
347+
private static class Event {
348+
int source;
349+
String message;
350+
351+
@Override
352+
public String toString() {
353+
return "Event => source: " + source + " message: " + message;
354+
}
355+
}
356+
183357
}
184358

185359
}

0 commit comments

Comments
 (0)