001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
import org.fusesource.hawtdispatch.Task;

import javax.net.ssl.*;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Locale;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
import static javax.net.ssl.SSLEngineResult.Status.*;
035
036/**
037 * An SSL Transport for secure communications.
038 *
039 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
040 */
041public class SslTransport extends TcpTransport implements SecuredSession {
042
043    /**
044     * Maps uri schemes to a protocol algorithm names.
045     * Valid algorithm names listed at:
046     * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
047     */
048    public static String protocol(String scheme) {
049        if( scheme.equals("tls") ) {
050            return "TLS";
051        } else if( scheme.startsWith("tlsv") ) {
052            return "TLSv"+scheme.substring(4);
053        } else if( scheme.equals("ssl") ) {
054            return "SSL";
055        } else if( scheme.startsWith("sslv") ) {
056            return "SSLv"+scheme.substring(4);
057        }
058        return null;
059    }
060
061    enum ClientAuth {
062        WANT, NEED, NONE
063    };
064
065    private ClientAuth clientAuth = ClientAuth.WANT;
066    private String disabledCypherSuites = null;
067    private String enabledCipherSuites = null;
068    
069    private SSLContext sslContext;
070    private SSLEngine engine;
071
072    private ByteBuffer readBuffer;
073    private boolean readUnderflow;
074
075    private ByteBuffer writeBuffer;
076    private boolean writeFlushing;
077
078    private ByteBuffer readOverflowBuffer;
079    private SSLChannel ssl_channel = new SSLChannel();
080
081
082    public void setSSLContext(SSLContext ctx) {
083        this.sslContext = ctx;
084    }
085
086    /**
087     * Allows subclasses of TcpTransportFactory to create custom instances of
088     * TcpTransport.
089     */
090    public static SslTransport createTransport(URI uri) throws Exception {
091        String protocol = protocol(uri.getScheme());
092        if( protocol !=null ) {
093            SslTransport rc = new SslTransport();
094            rc.setSSLContext(SSLContext.getInstance(protocol));
095            return rc;
096        }
097        return null;
098    }
099
100    public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
101
102        public int write(ByteBuffer plain) throws IOException {
103            return secure_write(plain);
104        }
105
106        public int read(ByteBuffer plain) throws IOException {
107            return secure_read(plain);
108        }
109
110        public boolean isOpen() {
111            return getSocketChannel().isOpen();
112        }
113
114        public void close() throws IOException {
115            getSocketChannel().close();
116        }
117
118        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
119            if(offset+length > srcs.length || length<0 || offset<0) {
120                throw new IndexOutOfBoundsException();
121            }
122            long rc=0;
123            for (int i = 0; i < length; i++) {
124                ByteBuffer src = srcs[offset+i];
125                if(src.hasRemaining()) {
126                    rc += write(src);
127                }
128                if( src.hasRemaining() ) {
129                    return rc;
130                }
131            }
132            return rc;
133        }
134
135        public long write(ByteBuffer[] srcs) throws IOException {
136            return write(srcs, 0, srcs.length);
137        }
138
139        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
140            if(offset+length > dsts.length || length<0 || offset<0) {
141                throw new IndexOutOfBoundsException();
142            }
143            long rc=0;
144            for (int i = 0; i < length; i++) {
145                ByteBuffer dst = dsts[offset+i];
146                if(dst.hasRemaining()) {
147                    rc += read(dst);
148                }
149                if( dst.hasRemaining() ) {
150                    return rc;
151                }
152            }
153            return rc;
154        }
155
156        public long read(ByteBuffer[] dsts) throws IOException {
157            return read(dsts, 0, dsts.length);
158        }
159        
160        public Socket socket() {
161            SocketChannel c = channel;
162            if( c == null ) {
163                return null;
164            }
165            return c.socket();
166        }
167    }
168
169    public SSLSession getSSLSession() {
170        return engine==null ? null : engine.getSession();
171    }
172
173    public X509Certificate[] getPeerX509Certificates() {
174        if( engine==null ) {
175            return null;
176        }
177        try {
178            ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
179            for( Certificate c:engine.getSession().getPeerCertificates() ) {
180                if(c instanceof X509Certificate) {
181                    rc.add((X509Certificate) c);
182                }
183            }
184            return rc.toArray(new X509Certificate[rc.size()]);
185        } catch (SSLPeerUnverifiedException e) {
186            return null;
187        }
188    }
189
190    @Override
191    public void connecting(URI remoteLocation, URI localLocation) throws Exception {
192        assert engine == null;
193        engine = sslContext.createSSLEngine(remoteLocation.getHost(), remoteLocation.getPort());
194        engine.setUseClientMode(true);
195        super.connecting(remoteLocation, localLocation);
196    }
197
198    @Override
199    public void connected(SocketChannel channel) throws Exception {
200        if (engine == null) {
201            engine = sslContext.createSSLEngine();
202            engine.setUseClientMode(false);
203            switch (clientAuth) {
204                case WANT: engine.setWantClientAuth(true); break;
205                case NEED: engine.setNeedClientAuth(true); break;
206                case NONE: engine.setWantClientAuth(false); break;
207            }
208
209        }
210
211        if (enabledCipherSuites != null) {
212            engine.setEnabledCipherSuites(splitOnCommas(enabledCipherSuites));
213        } else {
214            engine.setEnabledCipherSuites(engine.getSupportedCipherSuites());
215        }
216
217        if( disabledCypherSuites!=null ) {
218            String[] disabledList = splitOnCommas(disabledCypherSuites);
219            ArrayList<String> enabled = new ArrayList<String>();
220            for (String suite : engine.getEnabledCipherSuites()) {
221                boolean add = true;
222                for (String disabled : disabledList) {
223                    if( suite.contains(disabled) ) {
224                        add = false;
225                        break;
226                    }
227                }
228                if( add ) {
229                    enabled.add(suite);
230                }
231            }
232            engine.setEnabledCipherSuites(enabled.toArray(new String[enabled.size()]));
233        }
234
235        super.connected(channel);
236    }
237
238    private String[] splitOnCommas(String value) {
239        ArrayList<String> rc = new ArrayList<String>();
240        for( String x : value.split(",") ) {
241            rc.add(x.trim());
242        }
243        return rc.toArray(new String[rc.size()]);
244    }
245
246    @Override
247    protected void initializeChannel() throws Exception {
248        super.initializeChannel();
249        SSLSession session = engine.getSession();
250        readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
251        readBuffer.flip();
252        writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
253    }
254
255    @Override
256    protected void onConnected() throws IOException {
257        super.onConnected();
258        engine.beginHandshake();
259        handshake();
260    }
261
262    @Override
263    public void flush() {
264        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
265            handshake();
266        } else {
267            super.flush();
268        }
269    }
270
271    @Override
272    public void drainInbound() {
273        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
274            handshake();
275        } else {
276            super.drainInbound();
277        }
278    }
279
280    /**
281     * @return true if fully flushed.
282     * @throws IOException
283     */
284    protected boolean transportFlush() throws IOException {
285        while (true) {
286            if(writeFlushing) {
287                int count = super.getWriteChannel().write(writeBuffer);
288                if( !writeBuffer.hasRemaining() ) {
289                    writeBuffer.clear();
290                    writeFlushing = false;
291                    suspendWrite();
292                    return true;
293                } else {
294                    return false;
295                }
296            } else {
297                if( writeBuffer.position()!=0 ) {
298                    writeBuffer.flip();
299                    writeFlushing = true;
300                    resumeWrite();
301                } else {
302                    return true;
303                }
304            }
305        }
306    }
307
308    private int secure_write(ByteBuffer plain) throws IOException {
309        if( !transportFlush() ) {
310            // can't write anymore until the write_secured_buffer gets fully flushed out..
311            return 0;
312        }
313        int rc = 0;
314        while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
315            SSLEngineResult result = engine.wrap(plain, writeBuffer);
316            assert result.getStatus()!= BUFFER_OVERFLOW;
317            rc += result.bytesConsumed();
318            if( !transportFlush() || result.getStatus() == CLOSED) {
319                break;
320            }
321        }
322        if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
323            dispatchQueue.execute(new Task() {
324                public void run() {
325                    handshake();
326                }
327            });
328        }
329        return rc;
330    }
331
332    private int secure_read(ByteBuffer plain) throws IOException {
333        int rc=0;
334        while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
335            if( readOverflowBuffer !=null ) {
336                if(  plain.hasRemaining() ) {
337                    // lets drain the overflow buffer before trying to suck down anymore
338                    // network bytes.
339                    int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
340                    plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
341                    readOverflowBuffer.position(readOverflowBuffer.position()+size);
342                    if( !readOverflowBuffer.hasRemaining() ) {
343                        readOverflowBuffer = null;
344                    }
345                    rc += size;
346                } else {
347                    return rc;
348                }
349            } else if( readUnderflow ) {
350                int count = super.getReadChannel().read(readBuffer);
351                if( count == -1 ) {  // peer closed socket.
352                    if (rc==0) {
353                        return -1;
354                    } else {
355                        return rc;
356                    }
357                }
358                if( count==0 ) {  // no data available right now.
359                    return rc;
360                }
361                // read in some more data, perhaps now we can unwrap.
362                readUnderflow = false;
363                readBuffer.flip();
364            } else {
365                SSLEngineResult result = engine.unwrap(readBuffer, plain);
366                rc += result.bytesProduced();
367                if( result.getStatus() == BUFFER_OVERFLOW ) {
368                    readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
369                    result = engine.unwrap(readBuffer, readOverflowBuffer);
370                    if( readOverflowBuffer.position()==0 ) {
371                        readOverflowBuffer = null;
372                    } else {
373                        readOverflowBuffer.flip();
374                    }
375                }
376                switch( result.getStatus() ) {
377                    case CLOSED:
378                        if (rc==0) {
379                            engine.closeInbound();
380                            return -1;
381                        } else {
382                            return rc;
383                        }
384                    case OK:
385                        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
386                            dispatchQueue.execute(new Task() {
387                                public void run() {
388                                    handshake();
389                                }
390                            });
391                        }
392                        break;
393                    case BUFFER_UNDERFLOW:
394                        readBuffer.compact();
395                        readUnderflow = true;
396                        break;
397                    case BUFFER_OVERFLOW:
398                        throw new AssertionError("Unexpected case.");
399                }
400            }
401        }
402        return rc;
403    }
404
405    public void handshake() {
406        try {
407            if( !transportFlush() ) {
408                return;
409            }
410            switch (engine.getHandshakeStatus()) {
411                case NEED_TASK:
412                    final Runnable task = engine.getDelegatedTask();
413                    if( task!=null ) {
414                        blockingExecutor.execute(new Task() {
415                            public void run() {
416                                task.run();
417                                dispatchQueue.execute(new Task() {
418                                    public void run() {
419                                        if (isConnected()) {
420                                            handshake();
421                                        }
422                                    }
423                                });
424                            }
425                        });
426                    }
427                    break;
428
429                case NEED_WRAP:
430                    secure_write(ByteBuffer.allocate(0));
431                    break;
432
433                case NEED_UNWRAP:
434                    if( secure_read(ByteBuffer.allocate(0)) == -1) {
435                        throw new EOFException("Peer disconnected during ssl handshake");
436                    }
437                    break;
438
439                case FINISHED:
440                case NOT_HANDSHAKING:
441                    break;
442
443                default:
444                    System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
445                    break;
446            }
447        } catch (IOException e ) {
448            onTransportFailure(e);
449        } finally {
450            if( engine.getHandshakeStatus() == NOT_HANDSHAKING ) {
451                drainOutboundSource.merge(1);
452                super.drainInbound();
453            }
454        }
455    }
456
457
458    public ReadableByteChannel getReadChannel() {
459        return ssl_channel;
460    }
461
462    public WritableByteChannel getWriteChannel() {
463        return ssl_channel;
464    }
465
466    public String getClientAuth() {
467        return clientAuth.name();
468    }
469
470    public void setClientAuth(String clientAuth) {
471        this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
472    }
473
474    public String getDisabledCypherSuites() {
475        return disabledCypherSuites;
476    }
477
478    public String getEnabledCypherSuites() {
479        return enabledCipherSuites;
480    }
481
482    public void setDisabledCypherSuites(String disabledCypherSuites) {
483        this.disabledCypherSuites = disabledCypherSuites;
484    }
485    
486    public void setEnabledCypherSuites(String enabledCypherSuites) {
487        this.enabledCipherSuites = enabledCypherSuites;
488    }
489}
490
491