001/** 002 * Copyright (C) 2012 FuseSource, Inc. 003 * http://fusesource.com 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.fusesource.hawtdispatch.transport; 019 020import org.fusesource.hawtdispatch.Task; 021 022import javax.net.ssl.*; 023import java.io.EOFException; 024import java.io.IOException; 025import java.net.Socket; 026import java.net.URI; 027import java.nio.ByteBuffer; 028import java.nio.channels.*; 029import java.security.cert.Certificate; 030import java.security.cert.X509Certificate; 031import java.util.ArrayList; 032 033import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*; 034import static javax.net.ssl.SSLEngineResult.Status.*; 035 036/** 037 * An SSL Transport for secure communications. 038 * 039 * @author <a href="http://hiramchirino.com">Hiram Chirino</a> 040 */ 041public class SslTransport extends TcpTransport implements SecuredSession { 042 043 /** 044 * Maps uri schemes to a protocol algorithm names. 045 * Valid algorithm names listed at: 046 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext 047 */ 048 public static String protocol(String scheme) { 049 if( scheme.equals("tls") ) { 050 return "TLS"; 051 } else if( scheme.startsWith("tlsv") ) { 052 return "TLSv"+scheme.substring(4); 053 } else if( scheme.equals("ssl") ) { 054 return "SSL"; 055 } else if( scheme.startsWith("sslv") ) { 056 return "SSLv"+scheme.substring(4); 057 } 058 return null; 059 } 060 061 enum ClientAuth { 062 WANT, NEED, NONE 063 }; 064 065 private ClientAuth clientAuth = ClientAuth.WANT; 066 private String disabledCypherSuites = null; 067 private String enabledCipherSuites = null; 068 069 private SSLContext sslContext; 070 private SSLEngine engine; 071 072 private ByteBuffer readBuffer; 073 private boolean readUnderflow; 074 075 private ByteBuffer writeBuffer; 076 private boolean writeFlushing; 077 078 private ByteBuffer readOverflowBuffer; 079 private SSLChannel ssl_channel = new SSLChannel(); 080 081 082 public void setSSLContext(SSLContext ctx) { 083 this.sslContext = ctx; 084 } 085 086 /** 087 * Allows subclasses of TcpTransportFactory to create custom instances of 088 * TcpTransport. 089 */ 090 public static SslTransport createTransport(URI uri) throws Exception { 091 String protocol = protocol(uri.getScheme()); 092 if( protocol !=null ) { 093 SslTransport rc = new SslTransport(); 094 rc.setSSLContext(SSLContext.getInstance(protocol)); 095 return rc; 096 } 097 return null; 098 } 099 100 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel { 101 102 public int write(ByteBuffer plain) throws IOException { 103 return secure_write(plain); 104 } 105 106 public int read(ByteBuffer plain) throws IOException { 107 return secure_read(plain); 108 } 109 110 public boolean isOpen() { 111 return getSocketChannel().isOpen(); 112 } 113 114 public void close() throws IOException { 115 getSocketChannel().close(); 116 } 117 118 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { 119 if(offset+length > srcs.length || length<0 || offset<0) { 120 throw new IndexOutOfBoundsException(); 121 } 122 long rc=0; 123 for (int i = 0; i < length; i++) { 124 ByteBuffer src = srcs[offset+i]; 125 if(src.hasRemaining()) { 126 rc += write(src); 127 } 128 if( src.hasRemaining() ) { 129 return rc; 130 } 131 } 132 return rc; 133 } 134 135 public long write(ByteBuffer[] srcs) throws IOException { 136 return write(srcs, 0, srcs.length); 137 } 138 139 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { 140 if(offset+length > dsts.length || length<0 || offset<0) { 141 throw new IndexOutOfBoundsException(); 142 } 143 long rc=0; 144 for (int i = 0; i < length; i++) { 145 ByteBuffer dst = dsts[offset+i]; 146 if(dst.hasRemaining()) { 147 rc += read(dst); 148 } 149 if( dst.hasRemaining() ) { 150 return rc; 151 } 152 } 153 return rc; 154 } 155 156 public long read(ByteBuffer[] dsts) throws IOException { 157 return read(dsts, 0, dsts.length); 158 } 159 160 public Socket socket() { 161 SocketChannel c = channel; 162 if( c == null ) { 163 return null; 164 } 165 return c.socket(); 166 } 167 } 168 169 public SSLSession getSSLSession() { 170 return engine==null ? null : engine.getSession(); 171 } 172 173 public X509Certificate[] getPeerX509Certificates() { 174 if( engine==null ) { 175 return null; 176 } 177 try { 178 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>(); 179 for( Certificate c:engine.getSession().getPeerCertificates() ) { 180 if(c instanceof X509Certificate) { 181 rc.add((X509Certificate) c); 182 } 183 } 184 return rc.toArray(new X509Certificate[rc.size()]); 185 } catch (SSLPeerUnverifiedException e) { 186 return null; 187 } 188 } 189 190 @Override 191 public void connecting(URI remoteLocation, URI localLocation) throws Exception { 192 assert engine == null; 193 engine = sslContext.createSSLEngine(remoteLocation.getHost(), remoteLocation.getPort()); 194 engine.setUseClientMode(true); 195 super.connecting(remoteLocation, localLocation); 196 } 197 198 @Override 199 public void connected(SocketChannel channel) throws Exception { 200 if (engine == null) { 201 engine = sslContext.createSSLEngine(); 202 engine.setUseClientMode(false); 203 switch (clientAuth) { 204 case WANT: engine.setWantClientAuth(true); break; 205 case NEED: engine.setNeedClientAuth(true); break; 206 case NONE: engine.setWantClientAuth(false); break; 207 } 208 209 } 210 211 if (enabledCipherSuites != null) { 212 engine.setEnabledCipherSuites(splitOnCommas(enabledCipherSuites)); 213 } else { 214 engine.setEnabledCipherSuites(engine.getSupportedCipherSuites()); 215 } 216 217 if( disabledCypherSuites!=null ) { 218 String[] disabledList = splitOnCommas(disabledCypherSuites); 219 ArrayList<String> enabled = new ArrayList<String>(); 220 for (String suite : engine.getEnabledCipherSuites()) { 221 boolean add = true; 222 for (String disabled : disabledList) { 223 if( suite.contains(disabled) ) { 224 add = false; 225 break; 226 } 227 } 228 if( add ) { 229 enabled.add(suite); 230 } 231 } 232 engine.setEnabledCipherSuites(enabled.toArray(new String[enabled.size()])); 233 } 234 235 super.connected(channel); 236 } 237 238 private String[] splitOnCommas(String value) { 239 ArrayList<String> rc = new ArrayList<String>(); 240 for( String x : value.split(",") ) { 241 rc.add(x.trim()); 242 } 243 return rc.toArray(new String[rc.size()]); 244 } 245 246 @Override 247 protected void initializeChannel() throws Exception { 248 super.initializeChannel(); 249 SSLSession session = engine.getSession(); 250 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 251 readBuffer.flip(); 252 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 253 } 254 255 @Override 256 protected void onConnected() throws IOException { 257 super.onConnected(); 258 engine.beginHandshake(); 259 handshake(); 260 } 261 262 @Override 263 public void flush() { 264 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 265 handshake(); 266 } else { 267 super.flush(); 268 } 269 } 270 271 @Override 272 public void drainInbound() { 273 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 274 handshake(); 275 } else { 276 super.drainInbound(); 277 } 278 } 279 280 /** 281 * @return true if fully flushed. 282 * @throws IOException 283 */ 284 protected boolean transportFlush() throws IOException { 285 while (true) { 286 if(writeFlushing) { 287 int count = super.getWriteChannel().write(writeBuffer); 288 if( !writeBuffer.hasRemaining() ) { 289 writeBuffer.clear(); 290 writeFlushing = false; 291 suspendWrite(); 292 return true; 293 } else { 294 return false; 295 } 296 } else { 297 if( writeBuffer.position()!=0 ) { 298 writeBuffer.flip(); 299 writeFlushing = true; 300 resumeWrite(); 301 } else { 302 return true; 303 } 304 } 305 } 306 } 307 308 private int secure_write(ByteBuffer plain) throws IOException { 309 if( !transportFlush() ) { 310 // can't write anymore until the write_secured_buffer gets fully flushed out.. 311 return 0; 312 } 313 int rc = 0; 314 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) { 315 SSLEngineResult result = engine.wrap(plain, writeBuffer); 316 assert result.getStatus()!= BUFFER_OVERFLOW; 317 rc += result.bytesConsumed(); 318 if( !transportFlush() || result.getStatus() == CLOSED) { 319 break; 320 } 321 } 322 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 323 dispatchQueue.execute(new Task() { 324 public void run() { 325 handshake(); 326 } 327 }); 328 } 329 return rc; 330 } 331 332 private int secure_read(ByteBuffer plain) throws IOException { 333 int rc=0; 334 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) { 335 if( readOverflowBuffer !=null ) { 336 if( plain.hasRemaining() ) { 337 // lets drain the overflow buffer before trying to suck down anymore 338 // network bytes. 339 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining()); 340 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size); 341 readOverflowBuffer.position(readOverflowBuffer.position()+size); 342 if( !readOverflowBuffer.hasRemaining() ) { 343 readOverflowBuffer = null; 344 } 345 rc += size; 346 } else { 347 return rc; 348 } 349 } else if( readUnderflow ) { 350 int count = super.getReadChannel().read(readBuffer); 351 if( count == -1 ) { // peer closed socket. 352 if (rc==0) { 353 return -1; 354 } else { 355 return rc; 356 } 357 } 358 if( count==0 ) { // no data available right now. 359 return rc; 360 } 361 // read in some more data, perhaps now we can unwrap. 362 readUnderflow = false; 363 readBuffer.flip(); 364 } else { 365 SSLEngineResult result = engine.unwrap(readBuffer, plain); 366 rc += result.bytesProduced(); 367 if( result.getStatus() == BUFFER_OVERFLOW ) { 368 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); 369 result = engine.unwrap(readBuffer, readOverflowBuffer); 370 if( readOverflowBuffer.position()==0 ) { 371 readOverflowBuffer = null; 372 } else { 373 readOverflowBuffer.flip(); 374 } 375 } 376 switch( result.getStatus() ) { 377 case CLOSED: 378 if (rc==0) { 379 engine.closeInbound(); 380 return -1; 381 } else { 382 return rc; 383 } 384 case OK: 385 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 386 dispatchQueue.execute(new Task() { 387 public void run() { 388 handshake(); 389 } 390 }); 391 } 392 break; 393 case BUFFER_UNDERFLOW: 394 readBuffer.compact(); 395 readUnderflow = true; 396 break; 397 case BUFFER_OVERFLOW: 398 throw new AssertionError("Unexpected case."); 399 } 400 } 401 } 402 return rc; 403 } 404 405 public void handshake() { 406 try { 407 if( !transportFlush() ) { 408 return; 409 } 410 switch (engine.getHandshakeStatus()) { 411 case NEED_TASK: 412 final Runnable task = engine.getDelegatedTask(); 413 if( task!=null ) { 414 blockingExecutor.execute(new Task() { 415 public void run() { 416 task.run(); 417 dispatchQueue.execute(new Task() { 418 public void run() { 419 if (isConnected()) { 420 handshake(); 421 } 422 } 423 }); 424 } 425 }); 426 } 427 break; 428 429 case NEED_WRAP: 430 secure_write(ByteBuffer.allocate(0)); 431 break; 432 433 case NEED_UNWRAP: 434 if( secure_read(ByteBuffer.allocate(0)) == -1) { 435 throw new EOFException("Peer disconnected during ssl handshake"); 436 } 437 break; 438 439 case FINISHED: 440 case NOT_HANDSHAKING: 441 break; 442 443 default: 444 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus()); 445 break; 446 } 447 } catch (IOException e ) { 448 onTransportFailure(e); 449 } finally { 450 if( engine.getHandshakeStatus() == NOT_HANDSHAKING ) { 451 drainOutboundSource.merge(1); 452 super.drainInbound(); 453 } 454 } 455 } 456 457 458 public ReadableByteChannel getReadChannel() { 459 return ssl_channel; 460 } 461 462 public WritableByteChannel getWriteChannel() { 463 return ssl_channel; 464 } 465 466 public String getClientAuth() { 467 return clientAuth.name(); 468 } 469 470 public void setClientAuth(String clientAuth) { 471 this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase()); 472 } 473 474 public String getDisabledCypherSuites() { 475 return disabledCypherSuites; 476 } 477 478 public String getEnabledCypherSuites() { 479 return enabledCipherSuites; 480 } 481 482 public void setDisabledCypherSuites(String disabledCypherSuites) { 483 this.disabledCypherSuites = disabledCypherSuites; 484 } 485 486 public void setEnabledCypherSuites(String enabledCypherSuites) { 487 this.enabledCipherSuites = enabledCypherSuites; 488 } 489} 490 491