1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.spark.network.protocol; 19 20 import java.io.IOException; 21 import java.nio.ByteBuffer; 22 import java.nio.channels.WritableByteChannel; 23 24 import io.netty.buffer.ByteBuf; 25 import io.netty.buffer.Unpooled; 26 import io.netty.channel.FileRegion; 27 import io.netty.util.AbstractReferenceCounted; 28 import org.junit.Test; 29 import org.mockito.Mockito; 30 31 import static org.junit.Assert.*; 32 33 import org.apache.spark.network.TestManagedBuffer; 34 import org.apache.spark.network.buffer.ManagedBuffer; 35 import org.apache.spark.network.buffer.NettyManagedBuffer; 36 import org.apache.spark.network.util.ByteArrayWritableChannel; 37 38 public class MessageWithHeaderSuite { 39 40 @Test testSingleWrite()41 public void testSingleWrite() throws Exception { 42 testFileRegionBody(8, 8); 43 } 44 45 @Test testShortWrite()46 public void testShortWrite() throws Exception { 47 testFileRegionBody(8, 1); 48 } 49 50 @Test testByteBufBody()51 public void testByteBufBody() throws Exception { 52 ByteBuf header = Unpooled.copyLong(42); 53 ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); 54 assertEquals(1, header.refCnt()); 55 assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); 56 ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer); 57 58 Object body = managedBuf.convertToNetty(); 59 assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt()); 60 assertEquals(1, header.refCnt()); 61 62 MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); 63 ByteBuf result = doWrite(msg, 1); 64 assertEquals(msg.count(), result.readableBytes()); 65 assertEquals(42, result.readLong()); 66 assertEquals(84, result.readLong()); 67 68 assertTrue(msg.release()); 69 assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); 70 assertEquals(0, header.refCnt()); 71 } 72 73 @Test testDeallocateReleasesManagedBuffer()74 public void testDeallocateReleasesManagedBuffer() throws Exception { 75 ByteBuf header = Unpooled.copyLong(42); 76 ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); 77 ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); 78 assertEquals(2, body.refCnt()); 79 MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); 80 assertTrue(msg.release()); 81 Mockito.verify(managedBuf, Mockito.times(1)).release(); 82 assertEquals(0, body.refCnt()); 83 } 84 testFileRegionBody(int totalWrites, int writesPerCall)85 private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { 86 ByteBuf header = Unpooled.copyLong(42); 87 int headerLength = header.readableBytes(); 88 TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); 89 MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count()); 90 91 ByteBuf result = doWrite(msg, totalWrites / writesPerCall); 92 assertEquals(headerLength + region.count(), result.readableBytes()); 93 assertEquals(42, result.readLong()); 94 for (long i = 0; i < 8; i++) { 95 assertEquals(i, result.readLong()); 96 } 97 assertTrue(msg.release()); 98 } 99 doWrite(MessageWithHeader msg, int minExpectedWrites)100 private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { 101 int writes = 0; 102 ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count()); 103 while (msg.transfered() < msg.count()) { 104 msg.transferTo(channel, msg.transfered()); 105 writes++; 106 } 107 assertTrue("Not enough writes!", minExpectedWrites <= writes); 108 return Unpooled.wrappedBuffer(channel.getData()); 109 } 110 111 private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion { 112 113 private final int writeCount; 114 private final int writesPerCall; 115 private int written; 116 TestFileRegion(int totalWrites, int writesPerCall)117 TestFileRegion(int totalWrites, int writesPerCall) { 118 this.writeCount = totalWrites; 119 this.writesPerCall = writesPerCall; 120 } 121 122 @Override count()123 public long count() { 124 return 8 * writeCount; 125 } 126 127 @Override position()128 public long position() { 129 return 0; 130 } 131 132 @Override transfered()133 public long transfered() { 134 return 8 * written; 135 } 136 137 @Override transferTo(WritableByteChannel target, long position)138 public long transferTo(WritableByteChannel target, long position) throws IOException { 139 for (int i = 0; i < writesPerCall; i++) { 140 ByteBuf buf = Unpooled.copyLong((position / 8) + i); 141 ByteBuffer nio = buf.nioBuffer(); 142 while (nio.remaining() > 0) { 143 target.write(nio); 144 } 145 buf.release(); 146 written++; 147 } 148 return 8 * writesPerCall; 149 } 150 151 @Override deallocate()152 protected void deallocate() { 153 } 154 155 } 156 157 } 158