1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.network.protocol;
19 
20 import java.io.IOException;
21 import java.nio.ByteBuffer;
22 import java.nio.channels.WritableByteChannel;
23 
24 import io.netty.buffer.ByteBuf;
25 import io.netty.buffer.Unpooled;
26 import io.netty.channel.FileRegion;
27 import io.netty.util.AbstractReferenceCounted;
28 import org.junit.Test;
29 import org.mockito.Mockito;
30 
31 import static org.junit.Assert.*;
32 
33 import org.apache.spark.network.TestManagedBuffer;
34 import org.apache.spark.network.buffer.ManagedBuffer;
35 import org.apache.spark.network.buffer.NettyManagedBuffer;
36 import org.apache.spark.network.util.ByteArrayWritableChannel;
37 
38 public class MessageWithHeaderSuite {
39 
40   @Test
testSingleWrite()41   public void testSingleWrite() throws Exception {
42     testFileRegionBody(8, 8);
43   }
44 
45   @Test
testShortWrite()46   public void testShortWrite() throws Exception {
47     testFileRegionBody(8, 1);
48   }
49 
50   @Test
testByteBufBody()51   public void testByteBufBody() throws Exception {
52     ByteBuf header = Unpooled.copyLong(42);
53     ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
54     assertEquals(1, header.refCnt());
55     assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
56     ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
57 
58     Object body = managedBuf.convertToNetty();
59     assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
60     assertEquals(1, header.refCnt());
61 
62     MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
63     ByteBuf result = doWrite(msg, 1);
64     assertEquals(msg.count(), result.readableBytes());
65     assertEquals(42, result.readLong());
66     assertEquals(84, result.readLong());
67 
68     assertTrue(msg.release());
69     assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
70     assertEquals(0, header.refCnt());
71   }
72 
73   @Test
testDeallocateReleasesManagedBuffer()74   public void testDeallocateReleasesManagedBuffer() throws Exception {
75     ByteBuf header = Unpooled.copyLong(42);
76     ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
77     ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
78     assertEquals(2, body.refCnt());
79     MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
80     assertTrue(msg.release());
81     Mockito.verify(managedBuf, Mockito.times(1)).release();
82     assertEquals(0, body.refCnt());
83   }
84 
testFileRegionBody(int totalWrites, int writesPerCall)85   private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
86     ByteBuf header = Unpooled.copyLong(42);
87     int headerLength = header.readableBytes();
88     TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
89     MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
90 
91     ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
92     assertEquals(headerLength + region.count(), result.readableBytes());
93     assertEquals(42, result.readLong());
94     for (long i = 0; i < 8; i++) {
95       assertEquals(i, result.readLong());
96     }
97     assertTrue(msg.release());
98   }
99 
doWrite(MessageWithHeader msg, int minExpectedWrites)100   private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
101     int writes = 0;
102     ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
103     while (msg.transfered() < msg.count()) {
104       msg.transferTo(channel, msg.transfered());
105       writes++;
106     }
107     assertTrue("Not enough writes!", minExpectedWrites <= writes);
108     return Unpooled.wrappedBuffer(channel.getData());
109   }
110 
111   private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
112 
113     private final int writeCount;
114     private final int writesPerCall;
115     private int written;
116 
TestFileRegion(int totalWrites, int writesPerCall)117     TestFileRegion(int totalWrites, int writesPerCall) {
118       this.writeCount = totalWrites;
119       this.writesPerCall = writesPerCall;
120     }
121 
122     @Override
count()123     public long count() {
124       return 8 * writeCount;
125     }
126 
127     @Override
position()128     public long position() {
129       return 0;
130     }
131 
132     @Override
transfered()133     public long transfered() {
134       return 8 * written;
135     }
136 
137     @Override
transferTo(WritableByteChannel target, long position)138     public long transferTo(WritableByteChannel target, long position) throws IOException {
139       for (int i = 0; i < writesPerCall; i++) {
140         ByteBuf buf = Unpooled.copyLong((position / 8) + i);
141         ByteBuffer nio = buf.nioBuffer();
142         while (nio.remaining() > 0) {
143           target.write(nio);
144         }
145         buf.release();
146         written++;
147       }
148       return 8 * writesPerCall;
149     }
150 
151     @Override
deallocate()152     protected void deallocate() {
153     }
154 
155   }
156 
157 }
158