diff --git a/rxjava-core/src/main/java/rx/Observable.java b/rxjava-core/src/main/java/rx/Observable.java index 0cef17b1d3..4052965253 100644 --- a/rxjava-core/src/main/java/rx/Observable.java +++ b/rxjava-core/src/main/java/rx/Observable.java @@ -54,6 +54,7 @@ import rx.operators.OperationFinally; import rx.operators.OperationFirstOrDefault; import rx.operators.OperationGroupBy; +import rx.operators.OperationGroupJoin; import rx.operators.OperationInterval; import rx.operators.OperationJoin; import rx.operators.OperationJoinPatterns; @@ -6118,5 +6119,26 @@ public Observable>> toMultimap(Func1 Observable>> toMultimap(Func1 keySelector, Func1 valueSelector, Func0>> mapFactory, Func1> collectionFactory) { return create(OperationToMultimap.toMultimap(this, keySelector, valueSelector, mapFactory, collectionFactory)); - } + } + + /** + * Return an Observable which correlates two sequences when they overlap and groups the results. + * + * @param right the other Observable to correlate values of this observable to + * @param leftDuration function that returns an Observable which indicates the duration of + * the values of this Observable + * @param rightDuration function that returns an Observable which indicates the duration of + * the values of the right Observable + * @param resultSelector function that takes a left value, the right observable and returns the + * value to be emitted + * @return an Observable that emits grouped values based on overlapping durations from this and + * another Observable + * + * @see MSDN: Observable.GroupJoin + */ + public Observable groupJoin(Observable right, Func1> leftDuration, + Func1> rightDuration, + Func2, ? extends R> resultSelector) { + return create(new OperationGroupJoin(this, right, leftDuration, rightDuration, resultSelector)); + } } diff --git a/rxjava-core/src/main/java/rx/operators/OperationGroupJoin.java b/rxjava-core/src/main/java/rx/operators/OperationGroupJoin.java new file mode 100644 index 0000000000..fe3fb840ae --- /dev/null +++ b/rxjava-core/src/main/java/rx/operators/OperationGroupJoin.java @@ -0,0 +1,333 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import rx.Observable; +import rx.Observable.OnSubscribeFunc; +import rx.Observer; +import rx.Subscription; +import rx.subjects.PublishSubject; +import rx.subjects.Subject; +import rx.subscriptions.CompositeSubscription; +import rx.subscriptions.RefCountSubscription; +import rx.subscriptions.SerialSubscription; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +/** + * Corrrelates two sequences when they overlap and groups the results. + * + * @see MSDN: Observable.GroupJoin + */ +public class OperationGroupJoin implements OnSubscribeFunc { + protected final Observable left; + protected final Observable right; + protected final Func1> leftDuration; + protected final Func1> rightDuration; + protected final Func2, ? extends R> resultSelector; + public OperationGroupJoin( + Observable left, + Observable right, + Func1> leftDuration, + Func1> rightDuration, + Func2, ? extends R> resultSelector + ) { + this.left = left; + this.right = right; + this.leftDuration = leftDuration; + this.rightDuration = rightDuration; + this.resultSelector = resultSelector; + } + @Override + public Subscription onSubscribe(Observer t1) { + ResultManager ro = new ResultManager(t1); + ro.init(); + return ro; + } + /** Manages sub-observers and subscriptions. */ + class ResultManager implements Subscription { + final RefCountSubscription cancel; + final Observer observer; + final CompositeSubscription group; + final Object guard = new Object(); + int leftIds; + int rightIds; + final Map> leftMap = new HashMap>(); + final Map rightMap = new HashMap(); + boolean leftDone; + boolean rightDone; + public ResultManager(Observer observer) { + this.observer = observer; + this.group = new CompositeSubscription(); + this.cancel = new RefCountSubscription(group); + } + public void init() { + SerialSubscription s1 = new SerialSubscription(); + SerialSubscription s2 = new SerialSubscription(); + + group.add(s1); + group.add(s2); + + s1.setSubscription(left.subscribe(new LeftObserver(s1))); + s2.setSubscription(right.subscribe(new RightObserver(s2))); + + } + + @Override + public void unsubscribe() { + cancel.unsubscribe(); + } + void groupsOnCompleted() { + List> list = new ArrayList>(leftMap.values()); + leftMap.clear(); + rightMap.clear(); + for (Observer o : list) { + o.onCompleted(); + } + } + /** Observe the left source. */ + class LeftObserver implements Observer { + final Subscription tosource; + public LeftObserver(Subscription tosource) { + this.tosource = tosource; + } + @Override + public void onNext(T1 args) { + try { + int id; + Subject subj = PublishSubject.create(); + synchronized (guard) { + id = leftIds++; + leftMap.put(id, subj); + } + + Observable window = Observable.create(new WindowObservableFunc(subj, cancel)); + + Observable duration = leftDuration.call(args); + + SerialSubscription sduration = new SerialSubscription(); + group.add(sduration); + sduration.setSubscription(duration.subscribe(new LeftDurationObserver(id, sduration, subj))); + + R result = resultSelector.call(args, window); + + synchronized (guard) { + observer.onNext(result); + for (T2 t2 : rightMap.values()) { + subj.onNext(t2); + + } + } + } catch (Throwable t) { + onError(t); + } + } + + @Override + public void onCompleted() { + synchronized (guard) { + leftDone = true; + if (rightDone) { + groupsOnCompleted(); + observer.onCompleted(); + cancel.unsubscribe(); + } + } + } + + @Override + public void onError(Throwable e) { + synchronized (guard) { + for (Observer o : leftMap.values()) { + o.onError(e); + } + observer.onError(e); + cancel.unsubscribe(); + } + } + + + } + /** Observe the right source. */ + class RightObserver implements Observer { + final Subscription tosource; + public RightObserver(Subscription tosource) { + this.tosource = tosource; + } + @Override + public void onNext(T2 args) { + try { + int id; + synchronized (guard) { + id = rightIds++; + rightMap.put(id, args); + } + Observable duration = rightDuration.call(args); + + SerialSubscription sduration = new SerialSubscription(); + group.add(sduration); + sduration.setSubscription(duration.subscribe(new RightDurationObserver(id, sduration))); + + synchronized (guard) { + for (Observer o : leftMap.values()) { + o.onNext(args); + } + } + } catch (Throwable t) { + onError(t); + } + } + + @Override + public void onCompleted() { +// tosource.unsubscribe(); + synchronized (guard) { + rightDone = true; + if (leftDone) { + groupsOnCompleted(); + observer.onCompleted(); + cancel.unsubscribe(); + } + } + } + + @Override + public void onError(Throwable e) { + synchronized (guard) { + for (Observer o : leftMap.values()) { + o.onError(e); + } + + observer.onError(e); + cancel.unsubscribe(); + } + } + } + /** Observe left duration and apply termination. */ + class LeftDurationObserver implements Observer { + final int id; + final Subscription sduration; + final Observer gr; + public LeftDurationObserver(int id, Subscription sduration, Observer gr) { + this.id = id; + this.sduration = sduration; + this.gr = gr; + } + + @Override + public void onCompleted() { + synchronized (guard) { + if (leftMap.remove(id) != null) { + gr.onCompleted(); + } + } + group.remove(sduration); + } + + @Override + public void onError(Throwable e) { + synchronized (guard) { + observer.onError(e); + } + cancel.unsubscribe(); + } + + @Override + public void onNext(D1 args) { + onCompleted(); + } + } + /** Observe right duration and apply termination. */ + class RightDurationObserver implements Observer { + final int id; + final Subscription sduration; + public RightDurationObserver(int id, Subscription sduration) { + this.id = id; + this.sduration = sduration; + } + + @Override + public void onCompleted() { + synchronized (guard) { + rightMap.remove(id); + } + group.remove(sduration); + } + + @Override + public void onError(Throwable e) { + synchronized (guard) { + observer.onError(e); + } + cancel.unsubscribe(); + } + + @Override + public void onNext(D2 args) { + onCompleted(); + } + } + } + /** + * The reference-counted window observable. + * Subscribes to the underlying Observable by using a reference-counted + * subscription. + */ + static class WindowObservableFunc implements OnSubscribeFunc { + final RefCountSubscription refCount; + final Observable underlying; + public WindowObservableFunc(Observable underlying, RefCountSubscription refCount) { + this.refCount = refCount; + this.underlying = underlying; + } + + @Override + public Subscription onSubscribe(Observer t1) { + CompositeSubscription cs = new CompositeSubscription(); + cs.add(refCount.getSubscription()); + WindowObserver wo = new WindowObserver(t1, cs); + cs.add(underlying.subscribe(wo)); + return cs; + } + /** Observe activities on the window. */ + class WindowObserver implements Observer { + final Observer observer; + final Subscription self; + public WindowObserver(Observer observer, Subscription self) { + this.observer = observer; + this.self = self; + } + @Override + public void onNext(T args) { + observer.onNext(args); + } + @Override + public void onError(Throwable e) { + observer.onError(e); + self.unsubscribe(); + } + @Override + public void onCompleted() { + observer.onCompleted(); + self.unsubscribe(); + } + } + } +} diff --git a/rxjava-core/src/main/java/rx/subscriptions/RefCountSubscription.java b/rxjava-core/src/main/java/rx/subscriptions/RefCountSubscription.java new file mode 100644 index 0000000000..3044e1c001 --- /dev/null +++ b/rxjava-core/src/main/java/rx/subscriptions/RefCountSubscription.java @@ -0,0 +1,101 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.subscriptions; + +import java.util.concurrent.atomic.AtomicBoolean; +import rx.Subscription; + +/** + * Keeps track of the sub-subscriptions and unsubscribes the underlying + * subscription once all sub-subscriptions have unsubscribed. + * + * @see MSDN RefCountDisposable + */ +public class RefCountSubscription implements Subscription { + private final Object guard = new Object(); + private Subscription main; + private boolean done; + private int count; + public RefCountSubscription(Subscription s) { + if (s == null) { + throw new IllegalArgumentException("s"); + } + this.main = s; + } + /** + * Returns a new sub-subscription. + */ + public Subscription getSubscription() { + synchronized (guard) { + if (main == null) { + return Subscriptions.empty(); + } else { + count++; + return new InnerSubscription(); + } + } + } + /** + * Check if this subscription is already unsubscribed. + */ + public boolean isUnsubscribed() { + synchronized (guard) { + return main == null; + } + } + @Override + public void unsubscribe() { + Subscription s = null; + synchronized (guard) { + if (main != null && !done) { + done = true; + if (count == 0) { + s = main; + main = null; + } + } + } + if (s != null) { + s.unsubscribe(); + } + } + /** Remove an inner subscription. */ + void innerDone() { + Subscription s = null; + synchronized (guard) { + if (main != null) { + count--; + if (done && count == 0) { + s = main; + main = null; + } + } + } + if (s != null) { + s.unsubscribe(); + } + } + /** The individual sub-subscriptions. */ + class InnerSubscription implements Subscription { + final AtomicBoolean innerDone = new AtomicBoolean(); + @Override + public void unsubscribe() { + if (innerDone.compareAndSet(false, true)) { + innerDone(); + } + } + }; +} diff --git a/rxjava-core/src/test/java/rx/operators/OperationGroupJoinTest.java b/rxjava-core/src/test/java/rx/operators/OperationGroupJoinTest.java new file mode 100644 index 0000000000..cc28871c93 --- /dev/null +++ b/rxjava-core/src/test/java/rx/operators/OperationGroupJoinTest.java @@ -0,0 +1,344 @@ +/** + * Copyright 2013 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package rx.operators; + +import java.util.Arrays; +import org.junit.Before; +import org.junit.Test; +import static org.mockito.Matchers.any; +import org.mockito.Mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import org.mockito.MockitoAnnotations; +import rx.Observable; +import rx.Observer; +import rx.subjects.PublishSubject; +import rx.util.functions.Action1; +import rx.util.functions.Func1; +import rx.util.functions.Func2; + +public class OperationGroupJoinTest { + @Mock + Observer observer; + + Func2 add = new Func2() { + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + }; + Func1> just(final Observable observable) { + return new Func1>() { + @Override + public Observable call(Integer t1) { + return observable; + } + }; + } + Func1> just2(final Observable observable) { + return new Func1>() { + @Override + public Observable call(T t1) { + return observable; + } + }; + } + Func2, Observable> add2 = new Func2, Observable>() { + @Override + public Observable call(final Integer leftValue, Observable rightValues) { + return rightValues.map(new Func1() { + @Override + public Integer call(Integer rightValue) { + return add.call(leftValue, rightValue); + } + }); + } + + }; + @Before + public void before() { + MockitoAnnotations.initMocks(this); + } + @Test + public void behaveAsJoin() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable m = Observable.merge(source1.groupJoin(source2, + just(Observable.never()), + just(Observable.never()), add2)); + + m.subscribe(observer); + + source1.onNext(1); + source1.onNext(2); + source1.onNext(4); + + source2.onNext(16); + source2.onNext(32); + source2.onNext(64); + + source1.onCompleted(); + source2.onCompleted(); + + verify(observer, times(1)).onNext(17); + verify(observer, times(1)).onNext(18); + verify(observer, times(1)).onNext(20); + verify(observer, times(1)).onNext(33); + verify(observer, times(1)).onNext(34); + verify(observer, times(1)).onNext(36); + verify(observer, times(1)).onNext(65); + verify(observer, times(1)).onNext(66); + verify(observer, times(1)).onNext(68); + + verify(observer, times(1)).onCompleted(); //Never emitted? + verify(observer, never()).onError(any(Throwable.class)); + } + class Person { + final int id; + final String name; + public Person(int id, String name) { + this.id = id; + this.name = name; + } + } + class PersonFruit { + final int personId; + final String fruit; + public PersonFruit(int personId, String fruit) { + this.personId = personId; + this.fruit = fruit; + } + } + class PPF { + final Person person; + final Observable fruits; + public PPF(Person person, Observable fruits) { + this.person = person; + this.fruits = fruits; + } + } + @Test + public void normal1() { + Observable source1 = Observable.from(Arrays.asList( + new Person(1, "Joe"), + new Person(2, "Mike"), + new Person(3, "Charlie") + )); + + Observable source2 = Observable.from(Arrays.asList( + new PersonFruit(1, "Strawberry"), + new PersonFruit(1, "Apple"), + new PersonFruit(3, "Peach") + )); + + Observable q = source1.groupJoin( + source2, + just2(Observable.never()), + just2(Observable.never()), + new Func2, PPF>() { + @Override + public PPF call(Person t1, Observable t2) { + return new PPF(t1, t2); + } + }); + + q.subscribe( + new Observer() { + @Override + public void onNext(final PPF ppf) { + ppf.fruits.where(new Func1() { + @Override + public Boolean call(PersonFruit t1) { + return ppf.person.id == t1.personId; + } + }).subscribe(new Action1() { + @Override + public void call(PersonFruit t1) { + observer.onNext(Arrays.asList(ppf.person.name, t1.fruit)); + } + }); + } + + @Override + public void onError(Throwable e) { + observer.onError(e); + } + + @Override + public void onCompleted() { + observer.onCompleted(); + } + + } + ); + + verify(observer, times(1)).onNext(Arrays.asList("Joe", "Strawberry")); + verify(observer, times(1)).onNext(Arrays.asList("Joe", "Apple")); + verify(observer, times(1)).onNext(Arrays.asList("Charlie", "Peach")); + + verify(observer, times(1)).onCompleted(); + verify(observer, never()).onError(any(Throwable.class)); + } + @Test + public void leftThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable> m = source1.groupJoin(source2, + just(Observable.never()), + just(Observable.never()), add2); + + m.subscribe(observer); + + source2.onNext(1); + source1.onError(new RuntimeException("Forced failure")); + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable> m = source1.groupJoin(source2, + just(Observable.never()), + just(Observable.never()), add2); + + m.subscribe(observer); + + source1.onNext(1); + source2.onError(new RuntimeException("Forced failure")); + + verify(observer, times(1)).onNext(any(Observable.class)); + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + } + @Test + public void leftDurationThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable duration1 = Observable.error(new RuntimeException("Forced failure")); + + Observable> m = source1.groupJoin(source2, + just(duration1), + just(Observable.never()), add2); + m.subscribe(observer); + + source1.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightDurationThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Observable duration1 = Observable.error(new RuntimeException("Forced failure")); + + Observable> m = source1.groupJoin(source2, + just(Observable.never()), + just(duration1), add2); + m.subscribe(observer); + + source2.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void leftDurationSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func1> fail = new Func1>() { + @Override + public Observable call(Integer t1) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable> m = source1.groupJoin(source2, + fail, + just(Observable.never()), add2); + m.subscribe(observer); + + source1.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void rightDurationSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func1> fail = new Func1>() { + @Override + public Observable call(Integer t1) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable> m = source1.groupJoin(source2, + just(Observable.never()), + fail, add2); + m.subscribe(observer); + + source2.onNext(1); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } + @Test + public void resultSelectorThrows() { + PublishSubject source1 = PublishSubject.create(); + PublishSubject source2 = PublishSubject.create(); + + Func2, Integer> fail = new Func2, Integer>() { + @Override + public Integer call(Integer t1, Observable t2) { + throw new RuntimeException("Forced failure"); + } + }; + + Observable m = source1.groupJoin(source2, + just(Observable.never()), + just(Observable.never()), fail); + m.subscribe(observer); + + source1.onNext(1); + source2.onNext(2); + + + verify(observer, times(1)).onError(any(Throwable.class)); + verify(observer, never()).onCompleted(); + verify(observer, never()).onNext(any()); + } +}