Skip to content

Commit 870d0b7

Browse files
Merge pull request #288 from benjchristensen/issue282-PublishSubject
Fix PublishSubject non-deterministic behavior on concurrent modification
2 parents 85733b3 + e85c3d1 commit 870d0b7

File tree

1 file changed

+208
-9
lines changed

1 file changed

+208
-9
lines changed

rxjava-core/src/main/java/rx/subjects/PublishSubject.java

Lines changed: 208 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,30 @@
1515
*/
1616
package rx.subjects;
1717

18+
import static org.junit.Assert.*;
1819
import static org.mockito.Matchers.*;
1920
import static org.mockito.Mockito.*;
2021

2122
import java.util.ArrayList;
23+
import java.util.Collection;
2224
import java.util.List;
2325
import java.util.concurrent.ConcurrentHashMap;
2426
import java.util.concurrent.atomic.AtomicBoolean;
27+
import java.util.concurrent.atomic.AtomicInteger;
2528
import java.util.concurrent.atomic.AtomicReference;
2629

2730
import junit.framework.Assert;
2831

2932
import org.junit.Test;
33+
import org.mockito.InOrder;
3034
import org.mockito.Mockito;
3135

3236
import rx.Notification;
3337
import rx.Observable;
3438
import rx.Observer;
3539
import rx.Subscription;
3640
import rx.operators.AtomicObservableSubscription;
41+
import rx.subscriptions.Subscriptions;
3742
import rx.util.functions.Action1;
3843
import rx.util.functions.Func0;
3944
import rx.util.functions.Func1;
@@ -62,10 +67,15 @@
6267
public class PublishSubject<T> extends Subject<T, T> {
6368
public static <T> PublishSubject<T> create() {
6469
final ConcurrentHashMap<Subscription, Observer<T>> observers = new ConcurrentHashMap<Subscription, Observer<T>>();
65-
70+
final AtomicReference<Notification<T>> terminalState = new AtomicReference<Notification<T>>();
71+
6672
Func1<Observer<T>, Subscription> onSubscribe = new Func1<Observer<T>, Subscription>() {
6773
@Override
6874
public Subscription call(Observer<T> observer) {
75+
// shortcut check if terminal state exists already
76+
Subscription s = checkTerminalState(observer);
77+
if(s != null) return s;
78+
6979
final AtomicObservableSubscription subscription = new AtomicObservableSubscription();
7080

7181
subscription.wrap(new Subscription() {
@@ -76,43 +86,110 @@ public void unsubscribe() {
7686
}
7787
});
7888

79-
// on subscribe add it to the map of outbound observers to notify
80-
observers.put(subscription, observer);
81-
return subscription;
89+
/**
90+
* NOTE: We are synchronizing to avoid a race condition between terminalState being set and
91+
* a new observer being added to observers.
92+
*
93+
* The synchronization only occurs on subscription and terminal states, it does not affect onNext calls
94+
* so a high-volume hot-observable will not pay this cost for emitting data.
95+
*
96+
* Due to the restricted impact of blocking synchronization here I have not pursued more complicated
97+
* approaches to try and stay completely non-blocking.
98+
*/
99+
synchronized (terminalState) {
100+
// check terminal state again
101+
s = checkTerminalState(observer);
102+
if (s != null)
103+
return s;
104+
105+
// on subscribe add it to the map of outbound observers to notify
106+
observers.put(subscription, observer);
107+
108+
return subscription;
109+
}
110+
}
111+
112+
private Subscription checkTerminalState(Observer<T> observer) {
113+
Notification<T> n = terminalState.get();
114+
if (n != null) {
115+
// we are terminated to immediately emit and don't continue with subscription
116+
if (n.isOnCompleted()) {
117+
observer.onCompleted();
118+
} else {
119+
observer.onError(n.getException());
120+
}
121+
return Subscriptions.empty();
122+
} else {
123+
return null;
124+
}
82125
}
83126
};
84127

85-
return new PublishSubject<T>(onSubscribe, observers);
128+
return new PublishSubject<T>(onSubscribe, observers, terminalState);
86129
}
87130

88131
private final ConcurrentHashMap<Subscription, Observer<T>> observers;
132+
private final AtomicReference<Notification<T>> terminalState;
89133

90-
protected PublishSubject(Func1<Observer<T>, Subscription> onSubscribe, ConcurrentHashMap<Subscription, Observer<T>> observers) {
134+
protected PublishSubject(Func1<Observer<T>, Subscription> onSubscribe, ConcurrentHashMap<Subscription, Observer<T>> observers, AtomicReference<Notification<T>> terminalState) {
91135
super(onSubscribe);
92136
this.observers = observers;
137+
this.terminalState = terminalState;
93138
}
94139

95140
@Override
96141
public void onCompleted() {
97-
for (Observer<T> observer : observers.values()) {
142+
/**
143+
* Synchronizing despite terminalState being an AtomicReference because of multi-step logic in subscription.
144+
* Why use AtomicReference then? Convenient for passing around a mutable reference holder between the
145+
* onSubscribe function and PublishSubject instance... and it's a "better volatile" for the shortcut codepath.
146+
*/
147+
synchronized (terminalState) {
148+
terminalState.set(new Notification<T>());
149+
}
150+
for (Observer<T> observer : snapshotOfValues()) {
98151
observer.onCompleted();
99152
}
153+
observers.clear();
100154
}
101155

102156
@Override
103157
public void onError(Exception e) {
104-
for (Observer<T> observer : observers.values()) {
158+
/**
159+
* Synchronizing despite terminalState being an AtomicReference because of multi-step logic in subscription.
160+
* Why use AtomicReference then? Convenient for passing around a mutable reference holder between the
161+
* onSubscribe function and PublishSubject instance... and it's a "better volatile" for the shortcut codepath.
162+
*/
163+
synchronized (terminalState) {
164+
terminalState.set(new Notification<T>(e));
165+
}
166+
for (Observer<T> observer : snapshotOfValues()) {
105167
observer.onError(e);
106168
}
169+
observers.clear();
107170
}
108171

109172
@Override
110173
public void onNext(T args) {
111-
for (Observer<T> observer : observers.values()) {
174+
for (Observer<T> observer : snapshotOfValues()) {
112175
observer.onNext(args);
113176
}
114177
}
115178

179+
/**
180+
* Current snapshot of 'values()' so that concurrent modifications aren't included.
181+
*
182+
* This makes it behave deterministically in a single-threaded execution when nesting subscribes.
183+
*
184+
* In multi-threaded execution it will cause new subscriptions to wait until the following onNext instead
185+
* of possibly being included in the current onNext iteration.
186+
*
187+
* @return List<Observer<T>>
188+
*/
189+
private Collection<Observer<T>> snapshotOfValues() {
190+
return new ArrayList<Observer<T>>(observers.values());
191+
}
192+
116193
public static class UnitTest {
117194
@Test
118195
public void test() {
@@ -307,6 +384,75 @@ private void assertObservedUntilTwo(Observer<String> aObserver)
307384
verify(aObserver, Mockito.never()).onCompleted();
308385
}
309386

387+
/**
388+
* Test that subscribing after onError/onCompleted immediately terminates instead of causing it to hang.
389+
*
390+
* Nothing is mentioned in Rx Guidelines for what to do in this case so I'm doing what seems to make sense
391+
* which is:
392+
*
393+
* - cache terminal state (onError/onCompleted)
394+
* - any subsequent subscriptions will immediately receive the terminal state rather than start a new subscription
395+
*
396+
*/
397+
@Test
398+
public void testUnsubscribeAfterOnCompleted() {
399+
PublishSubject<Object> subject = PublishSubject.create();
400+
401+
@SuppressWarnings("unchecked")
402+
Observer<String> anObserver = mock(Observer.class);
403+
subject.subscribe(anObserver);
404+
405+
subject.onNext("one");
406+
subject.onNext("two");
407+
subject.onCompleted();
408+
409+
InOrder inOrder = inOrder(anObserver);
410+
inOrder.verify(anObserver, times(1)).onNext("one");
411+
inOrder.verify(anObserver, times(1)).onNext("two");
412+
inOrder.verify(anObserver, times(1)).onCompleted();
413+
inOrder.verify(anObserver, Mockito.never()).onError(any(Exception.class));
414+
415+
@SuppressWarnings("unchecked")
416+
Observer<String> anotherObserver = mock(Observer.class);
417+
subject.subscribe(anotherObserver);
418+
419+
inOrder = inOrder(anotherObserver);
420+
inOrder.verify(anotherObserver, Mockito.never()).onNext("one");
421+
inOrder.verify(anotherObserver, Mockito.never()).onNext("two");
422+
inOrder.verify(anotherObserver, times(1)).onCompleted();
423+
inOrder.verify(anotherObserver, Mockito.never()).onError(any(Exception.class));
424+
}
425+
426+
@Test
427+
public void testUnsubscribeAfterOnError() {
428+
PublishSubject<Object> subject = PublishSubject.create();
429+
RuntimeException exception = new RuntimeException("failure");
430+
431+
@SuppressWarnings("unchecked")
432+
Observer<String> anObserver = mock(Observer.class);
433+
subject.subscribe(anObserver);
434+
435+
subject.onNext("one");
436+
subject.onNext("two");
437+
subject.onError(exception);
438+
439+
InOrder inOrder = inOrder(anObserver);
440+
inOrder.verify(anObserver, times(1)).onNext("one");
441+
inOrder.verify(anObserver, times(1)).onNext("two");
442+
inOrder.verify(anObserver, times(1)).onError(exception);
443+
inOrder.verify(anObserver, Mockito.never()).onCompleted();
444+
445+
@SuppressWarnings("unchecked")
446+
Observer<String> anotherObserver = mock(Observer.class);
447+
subject.subscribe(anotherObserver);
448+
449+
inOrder = inOrder(anotherObserver);
450+
inOrder.verify(anotherObserver, Mockito.never()).onNext("one");
451+
inOrder.verify(anotherObserver, Mockito.never()).onNext("two");
452+
inOrder.verify(anotherObserver, times(1)).onError(exception);
453+
inOrder.verify(anotherObserver, Mockito.never()).onCompleted();
454+
}
455+
310456
@Test
311457
public void testUnsubscribe()
312458
{
@@ -340,5 +486,58 @@ public void call(PublishSubject<Object> DefaultSubject)
340486
}
341487
});
342488
}
489+
490+
@Test
491+
public void testNestedSubscribe() {
492+
final PublishSubject<Integer> s = PublishSubject.create();
493+
494+
final AtomicInteger countParent = new AtomicInteger();
495+
final AtomicInteger countChildren = new AtomicInteger();
496+
final AtomicInteger countTotal = new AtomicInteger();
497+
498+
final ArrayList<String> list = new ArrayList<String>();
499+
500+
s.mapMany(new Func1<Integer, Observable<String>>() {
501+
502+
@Override
503+
public Observable<String> call(final Integer v) {
504+
countParent.incrementAndGet();
505+
506+
// then subscribe to subject again (it will not receive the previous value)
507+
return s.map(new Func1<Integer, String>() {
508+
509+
@Override
510+
public String call(Integer v2) {
511+
countChildren.incrementAndGet();
512+
return "Parent: " + v + " Child: " + v2;
513+
}
514+
515+
});
516+
}
517+
518+
}).subscribe(new Action1<String>() {
519+
520+
@Override
521+
public void call(String v) {
522+
countTotal.incrementAndGet();
523+
list.add(v);
524+
}
525+
526+
});
527+
528+
529+
for(int i=0; i<10; i++) {
530+
s.onNext(i);
531+
}
532+
s.onCompleted();
533+
534+
// System.out.println("countParent: " + countParent.get());
535+
// System.out.println("countChildren: " + countChildren.get());
536+
// System.out.println("countTotal: " + countTotal.get());
537+
538+
// 9+8+7+6+5+4+3+2+1+0 == 45
539+
assertEquals(45, list.size());
540+
}
541+
343542
}
344543
}

0 commit comments

Comments
 (0)