001/** 002 * Copyright (C) 2012 FuseSource, Inc. 003 * http://fusesource.com 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.fusesource.hawtdispatch.transport; 019 020import org.fusesource.hawtdispatch.Task; 021 022import javax.net.ssl.*; 023import java.io.EOFException; 024import java.io.IOException; 025import java.nio.ByteBuffer; 026import java.nio.channels.GatheringByteChannel; 027import java.nio.channels.ReadableByteChannel; 028import java.nio.channels.ScatteringByteChannel; 029import java.nio.channels.WritableByteChannel; 030import java.security.cert.Certificate; 031import java.security.cert.X509Certificate; 032import java.util.ArrayList; 033 034import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*; 035import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW; 036 037/** 038 * Implements the SSL protocol as a WrappingProtocolCodec. Useful for when 039 * you want to switch to the SSL protocol on a regular TCP Transport. 040 */ 041public class SslProtocolCodec implements WrappingProtocolCodec, SecuredSession { 042 043 private ReadableByteChannel readChannel; 044 private WritableByteChannel writeChannel; 045 046 public enum ClientAuth { 047 WANT, NEED, NONE 048 }; 049 050 private SSLContext sslContext; 051 private SSLEngine engine; 052 053 private ByteBuffer readBuffer; 054 private boolean readUnderflow; 055 056 private ByteBuffer writeBuffer; 057 private boolean writeFlushing; 058 059 private ByteBuffer readOverflowBuffer; 060 Transport transport; 061 062 int lastReadSize; 063 int lastWriteSize; 064 long readCounter; 065 long writeCounter; 066 067 ProtocolCodec next; 068 069 070 public SslProtocolCodec() { 071 } 072 073 public ProtocolCodec getNext() { 074 return next; 075 } 076 public void setNext(ProtocolCodec next) { 077 this.next = next; 078 initNext(); 079 } 080 081 private void initNext() { 082 if( next!=null ) { 083 this.next.setTransport(new TransportFilter(transport){ 084 public ReadableByteChannel getReadChannel() { 085 return sslReadChannel; 086 } 087 public WritableByteChannel getWriteChannel() { 088 return sslWriteChannel; 089 } 090 }); 091 } 092 } 093 094 public void setSSLContext(SSLContext ctx) { 095 assert engine == null; 096 this.sslContext = ctx; 097 } 098 099 public SslProtocolCodec client() throws Exception { 100 initializeEngine(); 101 engine.setUseClientMode(true); 102 engine.beginHandshake(); 103 return this; 104 } 105 106 public SslProtocolCodec server(ClientAuth clientAuth) throws Exception { 107 initializeEngine(); 108 engine.setUseClientMode(false); 109 switch (clientAuth) { 110 case WANT: engine.setWantClientAuth(true); break; 111 case NEED: engine.setNeedClientAuth(true); break; 112 case NONE: engine.setWantClientAuth(false); break; 113 } 114 engine.beginHandshake(); 115 return this; 116 } 117 118 protected void initializeEngine() throws Exception { 119 assert engine == null; 120 if( sslContext == null ) { 121 sslContext = SSLContext.getDefault(); 122 } 123 engine = sslContext.createSSLEngine(); 124 SSLSession session = engine.getSession(); 125 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 126 readBuffer.flip(); 127 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 128 } 129 130 131 public SSLSession getSSLSession() { 132 return engine==null ? null : engine.getSession(); 133 } 134 135 public X509Certificate[] getPeerX509Certificates() { 136 if( engine==null ) { 137 return null; 138 } 139 try { 140 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>(); 141 for( Certificate c:engine.getSession().getPeerCertificates() ) { 142 if(c instanceof X509Certificate) { 143 rc.add((X509Certificate) c); 144 } 145 } 146 return rc.toArray(new X509Certificate[rc.size()]); 147 } catch (SSLPeerUnverifiedException e) { 148 return null; 149 } 150 } 151 152 SSLReadChannel sslReadChannel = new SSLReadChannel(); 153 SSLWriteChannel sslWriteChannel = new SSLWriteChannel(); 154 155 public void setTransport(Transport transport) { 156 this.transport = transport; 157 this.readChannel = transport.getReadChannel(); 158 this.writeChannel = transport.getWriteChannel(); 159 initNext(); 160 } 161 162 public void handshake() throws IOException { 163 if( !transportFlush() ) { 164 return; 165 } 166 switch (engine.getHandshakeStatus()) { 167 case NEED_TASK: 168 final Runnable task = engine.getDelegatedTask(); 169 if( task!=null ) { 170 transport.getBlockingExecutor().execute(new Task() { 171 public void run() { 172 task.run(); 173 transport.getDispatchQueue().execute(new Task() { 174 public void run() { 175 if (readChannel.isOpen() && writeChannel.isOpen()) { 176 try { 177 handshake(); 178 } catch (IOException e) { 179 transport.getTransportListener().onTransportFailure(e); 180 } 181 } 182 } 183 }); 184 } 185 }); 186 } 187 break; 188 189 case NEED_WRAP: 190 secure_write(ByteBuffer.allocate(0)); 191 break; 192 193 case NEED_UNWRAP: 194 if( secure_read(ByteBuffer.allocate(0)) == -1) { 195 throw new EOFException("Peer disconnected during ssl handshake"); 196 } 197 break; 198 199 case FINISHED: 200 case NOT_HANDSHAKING: 201 transport.drainInbound(); 202 transport.getTransportListener().onRefill(); 203 break; 204 205 default: 206 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus()); 207 break; 208 } 209 } 210 211 /** 212 * @return true if fully flushed. 213 * @throws IOException 214 */ 215 protected boolean transportFlush() throws IOException { 216 while (true) { 217 if(writeFlushing) { 218 lastWriteSize = writeChannel.write(writeBuffer); 219 if( lastWriteSize > 0 ) { 220 writeCounter += lastWriteSize; 221 } 222 if( !writeBuffer.hasRemaining() ) { 223 writeBuffer.clear(); 224 writeFlushing = false; 225 return true; 226 } else { 227 return false; 228 } 229 } else { 230 if( writeBuffer.position()!=0 ) { 231 writeBuffer.flip(); 232 writeFlushing = true; 233 } else { 234 return true; 235 } 236 } 237 } 238 } 239 240 private int secure_read(ByteBuffer plain) throws IOException { 241 int rc=0; 242 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) { 243 if( readOverflowBuffer !=null ) { 244 if( plain.hasRemaining() ) { 245 // lets drain the overflow buffer before trying to suck down anymore 246 // network bytes. 247 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining()); 248 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size); 249 readOverflowBuffer.position(readOverflowBuffer.position()+size); 250 if( !readOverflowBuffer.hasRemaining() ) { 251 readOverflowBuffer = null; 252 } 253 rc += size; 254 } else { 255 return rc; 256 } 257 } else if( readUnderflow ) { 258 lastReadSize = readChannel.read(readBuffer); 259 if( lastReadSize == -1 ) { // peer closed socket. 260 if (rc==0) { 261 return -1; 262 } else { 263 return rc; 264 } 265 } 266 if( lastReadSize==0 ) { // no data available right now. 267 return rc; 268 } 269 readCounter += lastReadSize; 270 // read in some more data, perhaps now we can unwrap. 271 readUnderflow = false; 272 readBuffer.flip(); 273 } else { 274 SSLEngineResult result = engine.unwrap(readBuffer, plain); 275 rc += result.bytesProduced(); 276 if( result.getStatus() == BUFFER_OVERFLOW ) { 277 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); 278 result = engine.unwrap(readBuffer, readOverflowBuffer); 279 if( readOverflowBuffer.position()==0 ) { 280 readOverflowBuffer = null; 281 } else { 282 readOverflowBuffer.flip(); 283 } 284 } 285 switch( result.getStatus() ) { 286 case CLOSED: 287 if (rc==0) { 288 engine.closeInbound(); 289 return -1; 290 } else { 291 return rc; 292 } 293 case OK: 294 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 295 handshake(); 296 } 297 break; 298 case BUFFER_UNDERFLOW: 299 readBuffer.compact(); 300 readUnderflow = true; 301 break; 302 case BUFFER_OVERFLOW: 303 throw new AssertionError("Unexpected case."); 304 } 305 } 306 } 307 return rc; 308 } 309 310 private int secure_write(ByteBuffer plain) throws IOException { 311 if( !transportFlush() ) { 312 // can't write anymore until the write_secured_buffer gets fully flushed out.. 313 return 0; 314 } 315 int rc = 0; 316 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) { 317 SSLEngineResult result = engine.wrap(plain, writeBuffer); 318 assert result.getStatus()!= BUFFER_OVERFLOW; 319 rc += result.bytesConsumed(); 320 if( !transportFlush() ) { 321 break; 322 } 323 } 324 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 325 handshake(); 326 } 327 return rc; 328 } 329 330 public class SSLReadChannel implements ScatteringByteChannel { 331 332 public int read(ByteBuffer plain) throws IOException { 333 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 334 handshake(); 335 } 336 return secure_read(plain); 337 } 338 339 public boolean isOpen() { 340 return readChannel.isOpen(); 341 } 342 343 public void close() throws IOException { 344 readChannel.close(); 345 } 346 347 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { 348 if(offset+length > dsts.length || length<0 || offset<0) { 349 throw new IndexOutOfBoundsException(); 350 } 351 long rc=0; 352 for (int i = 0; i < length; i++) { 353 ByteBuffer dst = dsts[offset+i]; 354 if(dst.hasRemaining()) { 355 rc += read(dst); 356 } 357 if( dst.hasRemaining() ) { 358 return rc; 359 } 360 } 361 return rc; 362 } 363 364 public long read(ByteBuffer[] dsts) throws IOException { 365 return read(dsts, 0, dsts.length); 366 } 367 } 368 369 public class SSLWriteChannel implements GatheringByteChannel { 370 371 public int write(ByteBuffer plain) throws IOException { 372 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 373 handshake(); 374 } 375 return secure_write(plain); 376 } 377 378 public boolean isOpen() { 379 return writeChannel.isOpen(); 380 } 381 382 public void close() throws IOException { 383 writeChannel.close(); 384 } 385 386 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { 387 if(offset+length > srcs.length || length<0 || offset<0) { 388 throw new IndexOutOfBoundsException(); 389 } 390 long rc=0; 391 for (int i = 0; i < length; i++) { 392 ByteBuffer src = srcs[offset+i]; 393 if(src.hasRemaining()) { 394 rc += write(src); 395 } 396 if( src.hasRemaining() ) { 397 return rc; 398 } 399 } 400 return rc; 401 } 402 403 public long write(ByteBuffer[] srcs) throws IOException { 404 return write(srcs, 0, srcs.length); 405 } 406 } 407 408 public void unread(byte[] buffer) { 409 readBuffer.compact(); 410 if( readBuffer.remaining() < buffer.length) { 411 throw new IllegalStateException("Cannot unread now"); 412 } 413 readBuffer.put(buffer); 414 readBuffer.flip(); 415 } 416 417 public Object read() throws IOException { 418 return next.read(); 419 } 420 421 public ProtocolCodec.BufferState write(Object value) throws IOException { 422 return next.write(value); 423 } 424 425 public ProtocolCodec.BufferState flush() throws IOException { 426 return next.flush(); 427 } 428 429 public boolean full() { 430 return next.full(); 431 } 432 433 public long getWriteCounter() { 434 return writeCounter; 435 } 436 437 public long getLastWriteSize() { 438 return lastWriteSize; 439 } 440 441 public long getReadCounter() { 442 return readCounter; 443 } 444 445 public long getLastReadSize() { 446 return lastReadSize; 447 } 448 449 public int getReadBufferSize() { 450 return readBuffer.capacity(); 451 } 452 453 public int getWriteBufferSize() { 454 return writeBuffer.capacity(); 455 } 456 457 458 459}