1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.spark.network; 19 20 import java.nio.ByteBuffer; 21 22 import io.netty.channel.Channel; 23 import io.netty.channel.local.LocalChannel; 24 import org.junit.Test; 25 26 import static org.junit.Assert.assertEquals; 27 import static org.mockito.Matchers.any; 28 import static org.mockito.Matchers.eq; 29 import static org.mockito.Mockito.*; 30 31 import org.apache.spark.network.buffer.ManagedBuffer; 32 import org.apache.spark.network.buffer.NioManagedBuffer; 33 import org.apache.spark.network.client.ChunkReceivedCallback; 34 import org.apache.spark.network.client.RpcResponseCallback; 35 import org.apache.spark.network.client.StreamCallback; 36 import org.apache.spark.network.client.TransportResponseHandler; 37 import org.apache.spark.network.protocol.ChunkFetchFailure; 38 import org.apache.spark.network.protocol.ChunkFetchSuccess; 39 import org.apache.spark.network.protocol.RpcFailure; 40 import org.apache.spark.network.protocol.RpcResponse; 41 import org.apache.spark.network.protocol.StreamChunkId; 42 import org.apache.spark.network.protocol.StreamFailure; 43 import org.apache.spark.network.protocol.StreamResponse; 44 import org.apache.spark.network.util.TransportFrameDecoder; 45 46 public class TransportResponseHandlerSuite { 47 @Test handleSuccessfulFetch()48 public void handleSuccessfulFetch() throws Exception { 49 StreamChunkId streamChunkId = new StreamChunkId(1, 0); 50 51 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); 52 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); 53 handler.addFetchRequest(streamChunkId, callback); 54 assertEquals(1, handler.numOutstandingRequests()); 55 56 handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); 57 verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); 58 assertEquals(0, handler.numOutstandingRequests()); 59 } 60 61 @Test handleFailedFetch()62 public void handleFailedFetch() throws Exception { 63 StreamChunkId streamChunkId = new StreamChunkId(1, 0); 64 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); 65 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); 66 handler.addFetchRequest(streamChunkId, callback); 67 assertEquals(1, handler.numOutstandingRequests()); 68 69 handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); 70 verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); 71 assertEquals(0, handler.numOutstandingRequests()); 72 } 73 74 @Test clearAllOutstandingRequests()75 public void clearAllOutstandingRequests() throws Exception { 76 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); 77 ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); 78 handler.addFetchRequest(new StreamChunkId(1, 0), callback); 79 handler.addFetchRequest(new StreamChunkId(1, 1), callback); 80 handler.addFetchRequest(new StreamChunkId(1, 2), callback); 81 assertEquals(3, handler.numOutstandingRequests()); 82 83 handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); 84 handler.exceptionCaught(new Exception("duh duh duhhhh")); 85 86 // should fail both b2 and b3 87 verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); 88 verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); 89 verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); 90 assertEquals(0, handler.numOutstandingRequests()); 91 } 92 93 @Test handleSuccessfulRPC()94 public void handleSuccessfulRPC() throws Exception { 95 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); 96 RpcResponseCallback callback = mock(RpcResponseCallback.class); 97 handler.addRpcRequest(12345, callback); 98 assertEquals(1, handler.numOutstandingRequests()); 99 100 // This response should be ignored. 101 handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7)))); 102 assertEquals(1, handler.numOutstandingRequests()); 103 104 ByteBuffer resp = ByteBuffer.allocate(10); 105 handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp))); 106 verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10))); 107 assertEquals(0, handler.numOutstandingRequests()); 108 } 109 110 @Test handleFailedRPC()111 public void handleFailedRPC() throws Exception { 112 TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); 113 RpcResponseCallback callback = mock(RpcResponseCallback.class); 114 handler.addRpcRequest(12345, callback); 115 assertEquals(1, handler.numOutstandingRequests()); 116 117 handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored 118 assertEquals(1, handler.numOutstandingRequests()); 119 120 handler.handle(new RpcFailure(12345, "oh no")); 121 verify(callback, times(1)).onFailure((Throwable) any()); 122 assertEquals(0, handler.numOutstandingRequests()); 123 } 124 125 @Test testActiveStreams()126 public void testActiveStreams() throws Exception { 127 Channel c = new LocalChannel(); 128 c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); 129 TransportResponseHandler handler = new TransportResponseHandler(c); 130 131 StreamResponse response = new StreamResponse("stream", 1234L, null); 132 StreamCallback cb = mock(StreamCallback.class); 133 handler.addStreamCallback(cb); 134 assertEquals(1, handler.numOutstandingRequests()); 135 handler.handle(response); 136 assertEquals(1, handler.numOutstandingRequests()); 137 handler.deactivateStream(); 138 assertEquals(0, handler.numOutstandingRequests()); 139 140 StreamFailure failure = new StreamFailure("stream", "uh-oh"); 141 handler.addStreamCallback(cb); 142 assertEquals(1, handler.numOutstandingRequests()); 143 handler.handle(failure); 144 assertEquals(0, handler.numOutstandingRequests()); 145 } 146 } 147