001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
020import org.fusesource.hawtdispatch.Task;
021
022import javax.net.ssl.*;
023import java.io.EOFException;
024import java.io.IOException;
025import java.nio.ByteBuffer;
026import java.nio.channels.GatheringByteChannel;
027import java.nio.channels.ReadableByteChannel;
028import java.nio.channels.ScatteringByteChannel;
029import java.nio.channels.WritableByteChannel;
030import java.security.cert.Certificate;
031import java.security.cert.X509Certificate;
032import java.util.ArrayList;
033
034import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
035import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
036
037/**
038 * Implements the SSL protocol as a WrappingProtocolCodec.  Useful for when
039 * you want to switch to the SSL protocol on a regular TCP Transport.
040 */
041public class SslProtocolCodec implements WrappingProtocolCodec, SecuredSession {
042
043    private ReadableByteChannel readChannel;
044    private WritableByteChannel writeChannel;
045
046    public enum ClientAuth {
047        WANT, NEED, NONE
048    };
049
050    private SSLContext sslContext;
051    private SSLEngine engine;
052
053    private ByteBuffer readBuffer;
054    private boolean readUnderflow;
055
056    private ByteBuffer writeBuffer;
057    private boolean writeFlushing;
058
059    private ByteBuffer readOverflowBuffer;
060    Transport transport;
061
062    int lastReadSize;
063    int lastWriteSize;
064    long readCounter;
065    long writeCounter;
066
067    ProtocolCodec next;
068
069
070    public SslProtocolCodec() {
071    }
072
073    public ProtocolCodec getNext() {
074        return next;
075    }
076    public void setNext(ProtocolCodec next) {
077        this.next = next;
078        initNext();
079    }
080
081    private void initNext() {
082        if( next!=null ) {
083            this.next.setTransport(new TransportFilter(transport){
084                public ReadableByteChannel getReadChannel() {
085                    return sslReadChannel;
086                }
087                public WritableByteChannel getWriteChannel() {
088                    return sslWriteChannel;
089                }
090            });
091        }
092    }
093
094    public void setSSLContext(SSLContext ctx) {
095        assert engine == null;
096        this.sslContext = ctx;
097    }
098
099    public SslProtocolCodec client() throws Exception {
100        initializeEngine();
101        engine.setUseClientMode(true);
102        engine.beginHandshake();
103        return this;
104    }
105
106    public SslProtocolCodec server(ClientAuth clientAuth) throws Exception {
107        initializeEngine();
108        engine.setUseClientMode(false);
109        switch (clientAuth) {
110            case WANT: engine.setWantClientAuth(true); break;
111            case NEED: engine.setNeedClientAuth(true); break;
112            case NONE: engine.setWantClientAuth(false); break;
113        }
114        engine.beginHandshake();
115        return this;
116    }
117
118    protected void initializeEngine() throws Exception {
119        assert engine == null;
120        if( sslContext == null ) {
121            sslContext = SSLContext.getDefault();
122        }
123        engine = sslContext.createSSLEngine();
124        SSLSession session = engine.getSession();
125        readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
126        readBuffer.flip();
127        writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
128    }
129
130
131    public SSLSession getSSLSession() {
132        return engine==null ? null : engine.getSession();
133    }
134
135    public X509Certificate[] getPeerX509Certificates() {
136        if( engine==null ) {
137            return null;
138        }
139        try {
140            ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
141            for( Certificate c:engine.getSession().getPeerCertificates() ) {
142                if(c instanceof X509Certificate) {
143                    rc.add((X509Certificate) c);
144                }
145            }
146            return rc.toArray(new X509Certificate[rc.size()]);
147        } catch (SSLPeerUnverifiedException e) {
148            return null;
149        }
150    }
151
152    SSLReadChannel sslReadChannel = new SSLReadChannel();
153    SSLWriteChannel sslWriteChannel = new SSLWriteChannel();
154
155    public void setTransport(Transport transport) {
156        this.transport = transport;
157        this.readChannel = transport.getReadChannel();
158        this.writeChannel = transport.getWriteChannel();
159        initNext();
160    }
161
162    public void handshake() throws IOException {
163        if( !transportFlush() ) {
164            return;
165        }
166        switch (engine.getHandshakeStatus()) {
167            case NEED_TASK:
168                final Runnable task = engine.getDelegatedTask();
169                if( task!=null ) {
170                    transport.getBlockingExecutor().execute(new Task() {
171                        public void run() {
172                            task.run();
173                            transport.getDispatchQueue().execute(new Task() {
174                                public void run() {
175                                    if (readChannel.isOpen() && writeChannel.isOpen()) {
176                                        try {
177                                            handshake();
178                                        } catch (IOException e) {
179                                            transport.getTransportListener().onTransportFailure(e);
180                                        }
181                                    }
182                                }
183                            });
184                        }
185                    });
186                }
187                break;
188
189            case NEED_WRAP:
190                secure_write(ByteBuffer.allocate(0));
191                break;
192
193            case NEED_UNWRAP:
194                if( secure_read(ByteBuffer.allocate(0)) == -1) {
195                    throw new EOFException("Peer disconnected during ssl handshake");
196                }
197                break;
198
199            case FINISHED:
200            case NOT_HANDSHAKING:
201                transport.drainInbound();
202                transport.getTransportListener().onRefill();
203                break;
204
205            default:
206                System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
207                break;
208        }
209    }
210
211    /**
212     * @return true if fully flushed.
213     * @throws IOException
214     */
215    protected boolean transportFlush() throws IOException {
216        while (true) {
217            if(writeFlushing) {
218                lastWriteSize = writeChannel.write(writeBuffer);
219                if( lastWriteSize > 0 ) {
220                    writeCounter += lastWriteSize;
221                }
222                if( !writeBuffer.hasRemaining() ) {
223                    writeBuffer.clear();
224                    writeFlushing = false;
225                    return true;
226                } else {
227                    return false;
228                }
229            } else {
230                if( writeBuffer.position()!=0 ) {
231                    writeBuffer.flip();
232                    writeFlushing = true;
233                } else {
234                    return true;
235                }
236            }
237        }
238    }
239
240    private int secure_read(ByteBuffer plain) throws IOException {
241        int rc=0;
242        while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
243            if( readOverflowBuffer !=null ) {
244                if(  plain.hasRemaining() ) {
245                    // lets drain the overflow buffer before trying to suck down anymore
246                    // network bytes.
247                    int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
248                    plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
249                    readOverflowBuffer.position(readOverflowBuffer.position()+size);
250                    if( !readOverflowBuffer.hasRemaining() ) {
251                        readOverflowBuffer = null;
252                    }
253                    rc += size;
254                } else {
255                    return rc;
256                }
257            } else if( readUnderflow ) {
258                lastReadSize = readChannel.read(readBuffer);
259                if( lastReadSize == -1 ) {  // peer closed socket.
260                    if (rc==0) {
261                        return -1;
262                    } else {
263                        return rc;
264                    }
265                }
266                if( lastReadSize==0 ) {  // no data available right now.
267                    return rc;
268                }
269                readCounter += lastReadSize;
270                // read in some more data, perhaps now we can unwrap.
271                readUnderflow = false;
272                readBuffer.flip();
273            } else {
274                SSLEngineResult result = engine.unwrap(readBuffer, plain);
275                rc += result.bytesProduced();
276                if( result.getStatus() == BUFFER_OVERFLOW ) {
277                    readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
278                    result = engine.unwrap(readBuffer, readOverflowBuffer);
279                    if( readOverflowBuffer.position()==0 ) {
280                        readOverflowBuffer = null;
281                    } else {
282                        readOverflowBuffer.flip();
283                    }
284                }
285                switch( result.getStatus() ) {
286                    case CLOSED:
287                        if (rc==0) {
288                            engine.closeInbound();
289                            return -1;
290                        } else {
291                            return rc;
292                        }
293                    case OK:
294                        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
295                            handshake();
296                        }
297                        break;
298                    case BUFFER_UNDERFLOW:
299                        readBuffer.compact();
300                        readUnderflow = true;
301                        break;
302                    case BUFFER_OVERFLOW:
303                        throw new AssertionError("Unexpected case.");
304                }
305            }
306        }
307        return rc;
308    }
309
310    private int secure_write(ByteBuffer plain) throws IOException {
311        if( !transportFlush() ) {
312            // can't write anymore until the write_secured_buffer gets fully flushed out..
313            return 0;
314        }
315        int rc = 0;
316        while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
317            SSLEngineResult result = engine.wrap(plain, writeBuffer);
318            assert result.getStatus()!= BUFFER_OVERFLOW;
319            rc += result.bytesConsumed();
320            if( !transportFlush() ) {
321                break;
322            }
323        }
324        if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
325            handshake();
326        }
327        return rc;
328    }
329
330    public class SSLReadChannel implements ScatteringByteChannel {
331
332        public int read(ByteBuffer plain) throws IOException {
333            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
334                handshake();
335            }
336            return secure_read(plain);
337        }
338
339        public boolean isOpen() {
340            return readChannel.isOpen();
341        }
342
343        public void close() throws IOException {
344            readChannel.close();
345        }
346
347        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
348            if(offset+length > dsts.length || length<0 || offset<0) {
349                throw new IndexOutOfBoundsException();
350            }
351            long rc=0;
352            for (int i = 0; i < length; i++) {
353                ByteBuffer dst = dsts[offset+i];
354                if(dst.hasRemaining()) {
355                    rc += read(dst);
356                }
357                if( dst.hasRemaining() ) {
358                    return rc;
359                }
360            }
361            return rc;
362        }
363
364        public long read(ByteBuffer[] dsts) throws IOException {
365            return read(dsts, 0, dsts.length);
366        }
367    }
368
369    public class SSLWriteChannel implements GatheringByteChannel {
370
371        public int write(ByteBuffer plain) throws IOException {
372            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
373                handshake();
374            }
375            return secure_write(plain);
376        }
377
378        public boolean isOpen() {
379            return writeChannel.isOpen();
380        }
381
382        public void close() throws IOException {
383            writeChannel.close();
384        }
385
386        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
387            if(offset+length > srcs.length || length<0 || offset<0) {
388                throw new IndexOutOfBoundsException();
389            }
390            long rc=0;
391            for (int i = 0; i < length; i++) {
392                ByteBuffer src = srcs[offset+i];
393                if(src.hasRemaining()) {
394                    rc += write(src);
395                }
396                if( src.hasRemaining() ) {
397                    return rc;
398                }
399            }
400            return rc;
401        }
402
403        public long write(ByteBuffer[] srcs) throws IOException {
404            return write(srcs, 0, srcs.length);
405        }
406    }
407
408    public void unread(byte[] buffer) {
409        readBuffer.compact();
410        if( readBuffer.remaining() < buffer.length) {
411            throw new IllegalStateException("Cannot unread now");
412        }
413        readBuffer.put(buffer);
414        readBuffer.flip();
415    }
416
417    public Object read() throws IOException {
418        return next.read();
419    }
420
421    public ProtocolCodec.BufferState write(Object value) throws IOException {
422        return next.write(value);
423    }
424
425    public ProtocolCodec.BufferState flush() throws IOException {
426        return next.flush();
427    }
428
429    public boolean full() {
430        return next.full();
431    }
432
433    public long getWriteCounter() {
434        return writeCounter;
435    }
436
437    public long getLastWriteSize() {
438        return lastWriteSize;
439    }
440
441    public long getReadCounter() {
442        return readCounter;
443    }
444
445    public long getLastReadSize() {
446        return lastReadSize;
447    }
448
449    public int getReadBufferSize() {
450        return readBuffer.capacity();
451    }
452
453    public int getWriteBufferSize() {
454        return writeBuffer.capacity();
455    }
456
457
458
459}