1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.network;
19 
20 import java.nio.ByteBuffer;
21 
22 import io.netty.channel.Channel;
23 import io.netty.channel.local.LocalChannel;
24 import org.junit.Test;
25 
26 import static org.junit.Assert.assertEquals;
27 import static org.mockito.Matchers.any;
28 import static org.mockito.Matchers.eq;
29 import static org.mockito.Mockito.*;
30 
31 import org.apache.spark.network.buffer.ManagedBuffer;
32 import org.apache.spark.network.buffer.NioManagedBuffer;
33 import org.apache.spark.network.client.ChunkReceivedCallback;
34 import org.apache.spark.network.client.RpcResponseCallback;
35 import org.apache.spark.network.client.StreamCallback;
36 import org.apache.spark.network.client.TransportResponseHandler;
37 import org.apache.spark.network.protocol.ChunkFetchFailure;
38 import org.apache.spark.network.protocol.ChunkFetchSuccess;
39 import org.apache.spark.network.protocol.RpcFailure;
40 import org.apache.spark.network.protocol.RpcResponse;
41 import org.apache.spark.network.protocol.StreamChunkId;
42 import org.apache.spark.network.protocol.StreamFailure;
43 import org.apache.spark.network.protocol.StreamResponse;
44 import org.apache.spark.network.util.TransportFrameDecoder;
45 
46 public class TransportResponseHandlerSuite {
47   @Test
handleSuccessfulFetch()48   public void handleSuccessfulFetch() throws Exception {
49     StreamChunkId streamChunkId = new StreamChunkId(1, 0);
50 
51     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
52     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
53     handler.addFetchRequest(streamChunkId, callback);
54     assertEquals(1, handler.numOutstandingRequests());
55 
56     handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123)));
57     verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
58     assertEquals(0, handler.numOutstandingRequests());
59   }
60 
61   @Test
handleFailedFetch()62   public void handleFailedFetch() throws Exception {
63     StreamChunkId streamChunkId = new StreamChunkId(1, 0);
64     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
65     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
66     handler.addFetchRequest(streamChunkId, callback);
67     assertEquals(1, handler.numOutstandingRequests());
68 
69     handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg"));
70     verify(callback, times(1)).onFailure(eq(0), (Throwable) any());
71     assertEquals(0, handler.numOutstandingRequests());
72   }
73 
74   @Test
clearAllOutstandingRequests()75   public void clearAllOutstandingRequests() throws Exception {
76     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
77     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
78     handler.addFetchRequest(new StreamChunkId(1, 0), callback);
79     handler.addFetchRequest(new StreamChunkId(1, 1), callback);
80     handler.addFetchRequest(new StreamChunkId(1, 2), callback);
81     assertEquals(3, handler.numOutstandingRequests());
82 
83     handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12)));
84     handler.exceptionCaught(new Exception("duh duh duhhhh"));
85 
86     // should fail both b2 and b3
87     verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any());
88     verify(callback, times(1)).onFailure(eq(1), (Throwable) any());
89     verify(callback, times(1)).onFailure(eq(2), (Throwable) any());
90     assertEquals(0, handler.numOutstandingRequests());
91   }
92 
93   @Test
handleSuccessfulRPC()94   public void handleSuccessfulRPC() throws Exception {
95     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
96     RpcResponseCallback callback = mock(RpcResponseCallback.class);
97     handler.addRpcRequest(12345, callback);
98     assertEquals(1, handler.numOutstandingRequests());
99 
100     // This response should be ignored.
101     handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
102     assertEquals(1, handler.numOutstandingRequests());
103 
104     ByteBuffer resp = ByteBuffer.allocate(10);
105     handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
106     verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
107     assertEquals(0, handler.numOutstandingRequests());
108   }
109 
110   @Test
handleFailedRPC()111   public void handleFailedRPC() throws Exception {
112     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
113     RpcResponseCallback callback = mock(RpcResponseCallback.class);
114     handler.addRpcRequest(12345, callback);
115     assertEquals(1, handler.numOutstandingRequests());
116 
117     handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored
118     assertEquals(1, handler.numOutstandingRequests());
119 
120     handler.handle(new RpcFailure(12345, "oh no"));
121     verify(callback, times(1)).onFailure((Throwable) any());
122     assertEquals(0, handler.numOutstandingRequests());
123   }
124 
125   @Test
testActiveStreams()126   public void testActiveStreams() throws Exception {
127     Channel c = new LocalChannel();
128     c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
129     TransportResponseHandler handler = new TransportResponseHandler(c);
130 
131     StreamResponse response = new StreamResponse("stream", 1234L, null);
132     StreamCallback cb = mock(StreamCallback.class);
133     handler.addStreamCallback(cb);
134     assertEquals(1, handler.numOutstandingRequests());
135     handler.handle(response);
136     assertEquals(1, handler.numOutstandingRequests());
137     handler.deactivateStream();
138     assertEquals(0, handler.numOutstandingRequests());
139 
140     StreamFailure failure = new StreamFailure("stream", "uh-oh");
141     handler.addStreamCallback(cb);
142     assertEquals(1, handler.numOutstandingRequests());
143     handler.handle(failure);
144     assertEquals(0, handler.numOutstandingRequests());
145   }
146 }
147