001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
020import org.fusesource.hawtbuf.Buffer;
021import org.fusesource.hawtbuf.DataByteArrayOutputStream;
022import org.fusesource.hawtdispatch.util.BufferPool;
023import org.fusesource.hawtdispatch.util.BufferPools;
024
025import java.io.EOFException;
026import java.io.IOException;
027import java.net.ProtocolException;
028import java.net.SocketException;
029import java.nio.ByteBuffer;
030import java.nio.channels.GatheringByteChannel;
031import java.nio.channels.ReadableByteChannel;
032import java.nio.channels.SocketChannel;
033import java.util.Arrays;
034import java.util.LinkedList;
035
036/**
037 * Provides an abstract base class to make implementing the ProtocolCodec interface
038 * easier.
039 *
040 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
041 */
042public abstract class AbstractProtocolCodec implements ProtocolCodec {
043
044    protected BufferPools bufferPools;
045    protected BufferPool writeBufferPool;
046    protected BufferPool readBufferPool;
047
048    protected int writeBufferSize = 1024 * 64;
049    protected long writeCounter = 0L;
050    protected GatheringByteChannel writeChannel = null;
051    protected DataByteArrayOutputStream nextWriteBuffer;
052    protected long lastWriteIoSize = 0;
053
054    protected LinkedList<ByteBuffer> writeBuffer = new LinkedList<ByteBuffer>();
055    private long writeBufferRemaining = 0;
056
057
058    public static interface Action {
059        Object apply() throws IOException;
060    }
061
062    protected long readCounter = 0L;
063    protected int readBufferSize = 1024 * 64;
064    protected ReadableByteChannel readChannel = null;
065    protected ByteBuffer readBuffer;
066    protected ByteBuffer directReadBuffer = null;
067
068    protected int readEnd;
069    protected int readStart;
070    protected int lastReadIoSize;
071    protected Action nextDecodeAction;
072
073    public void setTransport(Transport transport) {
074        this.writeChannel = (GatheringByteChannel) transport.getWriteChannel();
075        this.readChannel = transport.getReadChannel();
076        if( nextDecodeAction==null ) {
077            nextDecodeAction = initialDecodeAction();
078        }
079        if( transport instanceof TcpTransport) {
080            TcpTransport tcp = (TcpTransport) transport;
081            writeBufferSize = tcp.getSendBufferSize();
082            readBufferSize = tcp.getReceiveBufferSize();
083        } else if( transport instanceof UdpTransport) {
084            UdpTransport tcp = (UdpTransport) transport;
085            writeBufferSize = tcp.getSendBufferSize();
086            readBufferSize = tcp.getReceiveBufferSize();
087        } else {
088            try {
089                if (this.writeChannel instanceof SocketChannel) {
090                    writeBufferSize = ((SocketChannel) this.writeChannel).socket().getSendBufferSize();
091                    readBufferSize = ((SocketChannel) this.readChannel).socket().getReceiveBufferSize();
092                } else if (this.writeChannel instanceof SslTransport.SSLChannel) {
093                    writeBufferSize = ((SslTransport.SSLChannel) this.readChannel).socket().getSendBufferSize();
094                    readBufferSize = ((SslTransport.SSLChannel) this.writeChannel).socket().getReceiveBufferSize();
095                }
096            } catch (SocketException ignore) {
097            }
098        }
099        if( bufferPools!=null ) {
100            readBufferPool = bufferPools.getBufferPool(readBufferSize);
101            writeBufferPool = bufferPools.getBufferPool(writeBufferSize);
102        }
103    }
104
105    public int getReadBufferSize() {
106        return readBufferSize;
107    }
108
109    public int getWriteBufferSize() {
110        return writeBufferSize;
111    }
112
113    public boolean full() {
114        return writeBufferRemaining >= writeBufferSize;
115    }
116
117    public boolean isEmpty() {
118        return writeBufferRemaining == 0 && (nextWriteBuffer==null || nextWriteBuffer.size() == 0);
119    }
120
121    public long getWriteCounter() {
122        return writeCounter;
123    }
124
125    public long getLastWriteSize() {
126        return lastWriteIoSize;
127    }
128
129    abstract protected void encode(Object value) throws IOException;
130
131    public ProtocolCodec.BufferState write(Object value) throws IOException {
132        if (full()) {
133            return ProtocolCodec.BufferState.FULL;
134        } else {
135            boolean wasEmpty = isEmpty();
136            if( nextWriteBuffer == null ) {
137                nextWriteBuffer = allocateNextWriteBuffer();
138            }
139            encode(value);
140            if (nextWriteBuffer.size() >= (writeBufferSize* 0.75)) {
141                flushNextWriteBuffer();
142            }
143            if (wasEmpty) {
144                return ProtocolCodec.BufferState.WAS_EMPTY;
145            } else {
146                return ProtocolCodec.BufferState.NOT_EMPTY;
147            }
148        }
149    }
150
151    private DataByteArrayOutputStream allocateNextWriteBuffer() {
152        if( writeBufferPool !=null ) {
153            return new DataByteArrayOutputStream(writeBufferPool.checkout()) {
154                @Override
155                protected void resize(int newcount) {
156                    byte[] oldbuf = buf;
157                    super.resize(newcount);
158                    if( oldbuf.length == writeBufferPool.getBufferSize() ) {
159                        writeBufferPool.checkin(oldbuf);
160                    }
161                }
162            };
163        } else {
164            return new DataByteArrayOutputStream(writeBufferSize);
165        }
166    }
167
168    protected void writeDirect(ByteBuffer value) throws IOException {
169        // is the direct buffer small enough to just fit into the nextWriteBuffer?
170        int nextnextPospos = nextWriteBuffer.position();
171        int valuevalueLengthlength = value.remaining();
172        int available = nextWriteBuffer.getData().length - nextnextPospos;
173        if (available > valuevalueLengthlength) {
174            value.get(nextWriteBuffer.getData(), nextnextPospos, valuevalueLengthlength);
175            nextWriteBuffer.position(nextnextPospos + valuevalueLengthlength);
176        } else {
177            if (nextWriteBuffer!=null && nextWriteBuffer.size() != 0) {
178                flushNextWriteBuffer();
179            }
180            writeBuffer.add(value);
181            writeBufferRemaining += value.remaining();
182        }
183    }
184
185    protected void flushNextWriteBuffer() {
186        DataByteArrayOutputStream next = allocateNextWriteBuffer();
187        ByteBuffer bb = nextWriteBuffer.toBuffer().toByteBuffer();
188        writeBuffer.add(bb);
189        writeBufferRemaining += bb.remaining();
190        nextWriteBuffer = next;
191    }
192
193    public ProtocolCodec.BufferState flush() throws IOException {
194        while (true) {
195            if (writeBufferRemaining != 0) {
196                if( writeBuffer.size() == 1) {
197                    ByteBuffer b = writeBuffer.getFirst();
198                    lastWriteIoSize = writeChannel.write(b);
199                    if (lastWriteIoSize == 0) {
200                        return ProtocolCodec.BufferState.NOT_EMPTY;
201                    } else {
202                        writeBufferRemaining -= lastWriteIoSize;
203                        writeCounter += lastWriteIoSize;
204                        if(!b.hasRemaining()) {
205                            onBufferFlushed(writeBuffer.removeFirst());
206                        }
207                    }
208                } else {
209                    ByteBuffer[] buffers = writeBuffer.toArray(new ByteBuffer[writeBuffer.size()]);
210                    lastWriteIoSize = writeChannel.write(buffers, 0, buffers.length);
211                    if (lastWriteIoSize == 0) {
212                        return ProtocolCodec.BufferState.NOT_EMPTY;
213                    } else {
214                        writeBufferRemaining -= lastWriteIoSize;
215                        writeCounter += lastWriteIoSize;
216                        while (!writeBuffer.isEmpty() && !writeBuffer.getFirst().hasRemaining()) {
217                            onBufferFlushed(writeBuffer.removeFirst());
218                        }
219                    }
220                }
221            } else {
222                if (nextWriteBuffer==null || nextWriteBuffer.size() == 0) {
223                    if( writeBufferPool!=null &&  nextWriteBuffer!=null ) {
224                        writeBufferPool.checkin(nextWriteBuffer.getData());
225                        nextWriteBuffer = null;
226                    }
227                    return ProtocolCodec.BufferState.EMPTY;
228                } else {
229                    flushNextWriteBuffer();
230                }
231            }
232        }
233    }
234
235    /**
236     * Called when a buffer is flushed out.  Subclasses can implement
237     * in case they want to recycle the buffer.
238     *
239     * @param byteBuffer
240     */
241    protected void onBufferFlushed(ByteBuffer byteBuffer) {
242    }
243
244    /////////////////////////////////////////////////////////////////////
245    //
246    // Non blocking read impl
247    //
248    /////////////////////////////////////////////////////////////////////
249
250    abstract protected Action initialDecodeAction();
251
252
253    public void unread(byte[] buffer) {
254        assert ((readCounter == 0));
255        readBuffer = ByteBuffer.allocate(buffer.length);
256        readBuffer.put(buffer);
257        readCounter += buffer.length;
258    }
259
260    public long getReadCounter() {
261        return readCounter;
262    }
263
264    public long getLastReadSize() {
265        return lastReadIoSize;
266    }
267
268    public Object read() throws IOException {
269        Object command = null;
270        while (command == null) {
271            if (directReadBuffer != null) {
272                while (directReadBuffer.hasRemaining()) {
273                    lastReadIoSize = readChannel.read(directReadBuffer);
274                    readCounter += lastReadIoSize;
275                    if (lastReadIoSize == -1) {
276                        throw new EOFException("Peer disconnected");
277                    } else if (lastReadIoSize == 0) {
278                        return null;
279                    }
280                }
281                command = nextDecodeAction.apply();
282            } else {
283                if (readBuffer==null || readEnd >= readBuffer.position()) {
284
285                    int readPos = 0;
286                    boolean candidateForCheckin = false;
287                    if( readBuffer!=null ) {
288                        readPos = readBuffer.position();
289                        candidateForCheckin = readBufferPool!=null && readStart == 0 && readBuffer.capacity() == readBufferPool.getBufferSize();
290                    }
291
292                    if (readBuffer==null || readBuffer.remaining() == 0) {
293
294
295                        int loadedSize = readPos - readStart;
296                        int neededSize = readEnd - readStart;
297
298                        int newSize = 0;
299                        if( neededSize > loadedSize ) {
300                            newSize =  Math.max(readBufferSize, neededSize);
301                        } else {
302                            newSize = loadedSize+readBufferSize;
303                        }
304
305                        byte[] newBuffer;
306                        if (loadedSize > 0) {
307                            newBuffer = Arrays.copyOfRange(readBuffer.array(), readStart, readStart + newSize);
308                        } else {
309                            if( readBufferPool!=null && newSize == readBufferPool.getBufferSize()) {
310                                newBuffer = readBufferPool.checkout();
311                            } else {
312                                newBuffer =  new byte[newSize];
313                            }
314                        }
315
316                        if( candidateForCheckin ) {
317                            readBufferPool.checkin(readBuffer.array());
318                        }
319
320                        readBuffer = ByteBuffer.wrap(newBuffer);
321                        readBuffer.position(loadedSize);
322                        readStart = 0;
323                        readEnd = neededSize;
324                    }
325
326                    lastReadIoSize = readChannel.read(readBuffer);
327
328                    readCounter += lastReadIoSize;
329                    if (lastReadIoSize == -1) {
330                        readCounter += 1; // to compensate for that -1
331                        throw new EOFException("Peer disconnected");
332                    } else if (lastReadIoSize == 0) {
333                        if ( readStart == readBuffer.position() ) {
334                            if (candidateForCheckin) {
335                                readBufferPool.checkin(readBuffer.array());
336                            }
337                            readStart = 0;
338                            readEnd = 0;
339                            readBuffer = null;
340                        }
341                        return null;
342                    }
343
344                    // if we did not read a full buffer.. then resize the buffer
345                    if( readBuffer.hasRemaining() && readEnd <= readBuffer.position() ) {
346                        ByteBuffer perfectSized = ByteBuffer.wrap(Arrays.copyOfRange(readBuffer.array(), 0, readBuffer.position()));
347                        perfectSized.position(readBuffer.position());
348
349                        if( candidateForCheckin ) {
350                            readBufferPool.checkin(readBuffer.array());
351                        }
352                        readBuffer = perfectSized;
353                    }
354                }
355                command = nextDecodeAction.apply();
356                assert ((readStart <= readEnd));
357            }
358        }
359        return command;
360    }
361
362    protected Buffer readUntil(Byte octet) throws ProtocolException {
363        return readUntil(octet, -1);
364    }
365
366    protected Buffer readUntil(Byte octet, int max) throws ProtocolException {
367        return readUntil(octet, max, "Maximum protocol buffer length exeeded");
368    }
369
370    protected Buffer readUntil(Byte octet, int max, String msg) throws ProtocolException {
371        byte[] array = readBuffer.array();
372        Buffer buf = new Buffer(array, readEnd, readBuffer.position() - readEnd);
373        int pos = buf.indexOf(octet);
374        if (pos >= 0) {
375            int offset = readStart;
376            readEnd += pos + 1;
377            readStart = readEnd;
378            int length = readEnd - offset;
379            if (max >= 0 && length > max) {
380                throw new ProtocolException(msg);
381            }
382            return new Buffer(array, offset, length);
383        } else {
384            readEnd += buf.length;
385            if (max >= 0 && (readEnd - readStart) > max) {
386                throw new ProtocolException(msg);
387            }
388            return null;
389        }
390    }
391
392    protected Buffer readBytes(int length) {
393        readEnd = readStart + length;
394        if (readBuffer.position() < readEnd) {
395            return null;
396        } else {
397            int offset = readStart;
398            readStart = readEnd;
399            return new Buffer(readBuffer.array(), offset, length);
400        }
401    }
402
403    protected Buffer peekBytes(int length) {
404        readEnd = readStart + length;
405        if (readBuffer.position() < readEnd) {
406            return null;
407        } else {
408            // rewind..
409            readEnd = readStart;
410            return new Buffer(readBuffer.array(), readStart, length);
411        }
412    }
413
414    protected Boolean readDirect(ByteBuffer buffer) {
415        assert (directReadBuffer == null || (directReadBuffer == buffer));
416
417        if (buffer.hasRemaining()) {
418            // First we need to transfer the read bytes from the non-direct
419            // byte buffer into the direct one..
420            int limit = readBuffer.position();
421            int transferSize = Math.min((limit - readStart), buffer.remaining());
422            byte[] readBufferArray = readBuffer.array();
423            buffer.put(readBufferArray, readStart, transferSize);
424
425            // The direct byte buffer might have been smaller than our readBuffer one..
426            // compact the readBuffer to avoid doing additional mem allocations.
427            int trailingSize = limit - (readStart + transferSize);
428            if (trailingSize > 0) {
429                System.arraycopy(readBufferArray, readStart + transferSize, readBufferArray, readStart, trailingSize);
430            }
431            readBuffer.position(readStart + trailingSize);
432        }
433
434        // For big direct byte buffers, it will still not have been filled,
435        // so install it so that we directly read into it until it is filled.
436        if (buffer.hasRemaining()) {
437            directReadBuffer = buffer;
438            return false;
439        } else {
440            directReadBuffer = null;
441            buffer.flip();
442            return true;
443        }
444    }
445
446    public BufferPools getBufferPools() {
447        return bufferPools;
448    }
449
450    public void setBufferPools(BufferPools bufferPools) {
451        this.bufferPools = bufferPools;
452        if( bufferPools!=null ) {
453            readBufferPool = bufferPools.getBufferPool(readBufferSize);
454            writeBufferPool = bufferPools.getBufferPool(writeBufferSize);
455        } else {
456            readBufferPool = null;
457            writeBufferPool = null;
458        }
459    }
460}