001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.fusesource.hawtdispatch.transport;
018
019import org.fusesource.hawtdispatch.*;
020
021import java.io.IOException;
022import java.net.*;
023import java.nio.channels.DatagramChannel;
024import java.nio.channels.ReadableByteChannel;
025import java.nio.channels.SelectionKey;
026import java.nio.channels.WritableByteChannel;
027import java.util.LinkedList;
028import java.util.concurrent.Executor;
029
030/**
031 * <p>
032 * </p>
033 *
034 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
035 */
036public class UdpTransport extends ServiceBase implements Transport {
037
038    public static final SocketAddress ANY_ADDRESS = new SocketAddress() {
039        @Override
040        public String toString() {
041            return "*:*";
042        }
043    };
044
045
046    abstract static class SocketState {
047        void onStop(Task onCompleted) {
048        }
049        void onCanceled() {
050        }
051        boolean is(Class<? extends SocketState> clazz) {
052            return getClass()==clazz;
053        }
054    }
055
056    static class DISCONNECTED extends SocketState{}
057
058    class CONNECTING extends SocketState{
059        void onStop(Task onCompleted) {
060            trace("CONNECTING.onStop");
061            CANCELING state = new CANCELING();
062            socketState = state;
063            state.onStop(onCompleted);
064        }
065        void onCanceled() {
066            trace("CONNECTING.onCanceled");
067            CANCELING state = new CANCELING();
068            socketState = state;
069            state.onCanceled();
070        }
071    }
072
073    class CONNECTED extends SocketState {
074        public CONNECTED() {
075            localAddress = channel.socket().getLocalSocketAddress();
076            remoteAddress = channel.socket().getRemoteSocketAddress();
077            if(remoteAddress == null ) {
078                remoteAddress = ANY_ADDRESS;
079            }
080        }
081
082        void onStop(Task onCompleted) {
083            trace("CONNECTED.onStop");
084            CANCELING state = new CANCELING();
085            socketState = state;
086            state.add(createDisconnectTask());
087            state.onStop(onCompleted);
088        }
089        void onCanceled() {
090            trace("CONNECTED.onCanceled");
091            CANCELING state = new CANCELING();
092            socketState = state;
093            state.add(createDisconnectTask());
094            state.onCanceled();
095        }
096        Task createDisconnectTask() {
097            return new Task(){
098                public void run() {
099                    listener.onTransportDisconnected();
100                }
101            };
102        }
103    }
104
105    class CANCELING extends SocketState {
106        private LinkedList<Task> runnables =  new LinkedList<Task>();
107        private int remaining;
108        private boolean dispose;
109
110        public CANCELING() {
111            if( readSource!=null ) {
112                remaining++;
113                readSource.cancel();
114            }
115            if( writeSource!=null ) {
116                remaining++;
117                writeSource.cancel();
118            }
119        }
120        void onStop(Task onCompleted) {
121            trace("CANCELING.onCompleted");
122            add(onCompleted);
123            dispose = true;
124        }
125        void add(Task onCompleted) {
126            if( onCompleted!=null ) {
127                runnables.add(onCompleted);
128            }
129        }
130        void onCanceled() {
131            trace("CANCELING.onCanceled");
132            remaining--;
133            if( remaining!=0 ) {
134                return;
135            }
136            try {
137                channel.close();
138            } catch (IOException ignore) {
139            }
140            socketState = new CANCELED(dispose);
141            for (Task runnable : runnables) {
142                runnable.run();
143            }
144            if (dispose) {
145                dispose();
146            }
147        }
148    }
149
150    class CANCELED extends SocketState {
151        private boolean disposed;
152
153        public CANCELED(boolean disposed) {
154            this.disposed=disposed;
155        }
156
157        void onStop(Task onCompleted) {
158            trace("CANCELED.onStop");
159            if( !disposed ) {
160                disposed = true;
161                dispose();
162            }
163            onCompleted.run();
164        }
165    }
166
167    protected URI remoteLocation;
168    protected URI localLocation;
169    protected TransportListener listener;
170    protected ProtocolCodec codec;
171
172    protected DatagramChannel channel;
173
174    protected SocketState socketState = new DISCONNECTED();
175
176    protected DispatchQueue dispatchQueue;
177    private DispatchSource readSource;
178    private DispatchSource writeSource;
179    protected CustomDispatchSource<Integer, Integer> drainOutboundSource;
180    protected CustomDispatchSource<Integer, Integer> yieldSource;
181
182    protected boolean useLocalHost = true;
183
184    int receiveBufferSize = 1024*64;
185    int sendBufferSize = 1024*64;
186
187
188    public static final int IPTOS_LOWCOST = 0x02;
189    public static final int IPTOS_RELIABILITY = 0x04;
190    public static final int IPTOS_THROUGHPUT = 0x08;
191    public static final int IPTOS_LOWDELAY = 0x10;
192
193    int trafficClass = IPTOS_THROUGHPUT;
194
195    SocketAddress localAddress;
196    SocketAddress remoteAddress = ANY_ADDRESS;
197    Executor blockingExecutor;
198
199    private final Task CANCEL_HANDLER = new Task() {
200        public void run() {
201            socketState.onCanceled();
202        }
203    };
204
205    static final class OneWay {
206        final Object command;
207        final Retained retained;
208
209        public OneWay(Object command, Retained retained) {
210            this.command = command;
211            this.retained = retained;
212        }
213    }
214
215    public void connected(DatagramChannel channel) throws IOException, Exception {
216        this.channel = channel;
217        initializeChannel();
218        this.socketState = new CONNECTED();
219    }
220
221    protected void initializeChannel() throws Exception {
222        this.channel.configureBlocking(false);
223        DatagramSocket socket = channel.socket();
224        try {
225            socket.setReuseAddress(true);
226        } catch (SocketException e) {
227        }
228        try {
229            socket.setTrafficClass(trafficClass);
230        } catch (SocketException e) {
231        }
232        try {
233            socket.setReceiveBufferSize(receiveBufferSize);
234        } catch (SocketException e) {
235        }
236        try {
237            socket.setSendBufferSize(sendBufferSize);
238        } catch (SocketException e) {
239        }
240        if( channel!=null && codec!=null ) {
241            initializeCodec();
242        }
243    }
244
245    protected void initializeCodec() throws Exception {
246        codec.setTransport(this);
247    }
248
249    public void connecting(final URI remoteLocation, final URI localLocation) throws Exception {
250        this.channel = DatagramChannel.open();
251        initializeChannel();
252        this.remoteLocation = remoteLocation;
253        this.localLocation = localLocation;
254        socketState = new CONNECTING();
255    }
256
257
258    public DispatchQueue getDispatchQueue() {
259        return dispatchQueue;
260    }
261
262    public void setDispatchQueue(DispatchQueue queue) {
263        this.dispatchQueue = queue;
264        if(readSource!=null) readSource.setTargetQueue(queue);
265        if(writeSource!=null) writeSource.setTargetQueue(queue);
266        if(drainOutboundSource!=null) drainOutboundSource.setTargetQueue(queue);
267        if(yieldSource!=null) yieldSource.setTargetQueue(queue);
268    }
269
270    public void _start(Task onCompleted) {
271        try {
272            if ( socketState.is(CONNECTING.class) ) {
273                // Resolving host names might block.. so do it on the blocking executor.
274                this.blockingExecutor.execute(new Runnable() {
275                    public void run() {
276                        // No need to complete if we have been canceled.
277                        if( ! socketState.is(CONNECTING.class) ) {
278                            return;
279                        }
280                        try {
281
282                            final InetSocketAddress localAddress = (localLocation != null) ?
283                                 new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort())
284                                 : null;
285
286                            String host = resolveHostName(remoteLocation.getHost());
287                            final InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort());
288
289                            // Done resolving.. switch back to the dispatch queue.
290                            dispatchQueue.execute(new Task() {
291                                @Override
292                                public void run() {
293                                    try {
294                                        if(localAddress!=null) {
295                                            channel.socket().bind(localAddress);
296                                        }
297                                        channel.connect(remoteAddress);
298                                    } catch (IOException e) {
299                                        try {
300                                            channel.close();
301                                        } catch (IOException ignore) {
302                                        }
303                                        socketState = new CANCELED(true);
304                                        listener.onTransportFailure(e);
305                                    }
306                                }
307                            });
308
309                        } catch (IOException e) {
310                            try {
311                                channel.close();
312                            } catch (IOException ignore) {
313                            }
314                            socketState = new CANCELED(true);
315                            listener.onTransportFailure(e);
316                        }
317                    }
318                });
319
320            } else if (socketState.is(CONNECTED.class) ) {
321                dispatchQueue.execute(new Task() {
322                    public void run() {
323                        try {
324                            trace("was connected.");
325                            onConnected();
326                        } catch (IOException e) {
327                             onTransportFailure(e);
328                        }
329                    }
330                });
331            } else {
332                System.err.println("cannot be started.  socket state is: "+socketState);
333            }
334        } finally {
335            if( onCompleted!=null ) {
336                onCompleted.run();
337            }
338        }
339    }
340
341    public void _stop(final Task onCompleted) {
342        trace("stopping.. at state: "+socketState);
343        socketState.onStop(onCompleted);
344    }
345
346    protected String resolveHostName(String host) throws UnknownHostException {
347        String localName = InetAddress.getLocalHost().getHostName();
348        if (localName != null && isUseLocalHost()) {
349            if (localName.equals(host)) {
350                return "localhost";
351            }
352        }
353        return host;
354    }
355
356    protected void onConnected() throws IOException {
357        yieldSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
358        yieldSource.setEventHandler(new Task() {
359            public void run() {
360                drainInbound();
361            }
362        });
363        yieldSource.resume();
364        drainOutboundSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
365        drainOutboundSource.setEventHandler(new Task() {
366            public void run() {
367                flush();
368            }
369        });
370        drainOutboundSource.resume();
371
372        readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
373        writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
374
375        readSource.setCancelHandler(CANCEL_HANDLER);
376        writeSource.setCancelHandler(CANCEL_HANDLER);
377
378        readSource.setEventHandler(new Task() {
379            public void run() {
380                drainInbound();
381            }
382        });
383        writeSource.setEventHandler(new Task() {
384            public void run() {
385                flush();
386            }
387        });
388        listener.onTransportConnected();
389    }
390
391    Task onDispose;
392
393    private void dispose() {
394        if( readSource!=null ) {
395            readSource.cancel();
396            readSource=null;
397        }
398
399        if( writeSource!=null ) {
400            writeSource.cancel();
401            writeSource=null;
402        }
403        this.codec = null;
404        if(onDispose!=null) {
405            onDispose.run();
406            onDispose = null;
407        }
408    }
409
410    public void onTransportFailure(IOException error) {
411        listener.onTransportFailure(error);
412        socketState.onCanceled();
413    }
414
415
416    public boolean full() {
417        return codec==null || codec.full();
418    }
419
420    boolean rejectingOffers;
421
422    public boolean offer(Object command) {
423        dispatchQueue.assertExecuting();
424        try {
425            if (!socketState.is(CONNECTED.class)) {
426                throw new IOException("Not connected.");
427            }
428            if (getServiceState() != STARTED) {
429                throw new IOException("Not running.");
430            }
431
432            ProtocolCodec.BufferState rc = codec.write(command);
433            rejectingOffers = codec.full();
434            switch (rc ) {
435                case FULL:
436                    return false;
437                default:
438                    drainOutboundSource.merge(1);
439                    return true;
440            }
441        } catch (IOException e) {
442            onTransportFailure(e);
443            return false;
444        }
445
446    }
447
448    boolean writeResumedForCodecFlush = false;
449
450    /**
451     *
452     */
453    public void flush() {
454        dispatchQueue.assertExecuting();
455        if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
456            return;
457        }
458        try {
459            if( codec.flush() == ProtocolCodec.BufferState.EMPTY && transportFlush() ) {
460                if( writeResumedForCodecFlush) {
461                    writeResumedForCodecFlush = false;
462                    suspendWrite();
463                }
464                rejectingOffers = false;
465                listener.onRefill();
466
467            } else {
468                if(!writeResumedForCodecFlush) {
469                    writeResumedForCodecFlush = true;
470                    resumeWrite();
471                }
472            }
473        } catch (IOException e) {
474            onTransportFailure(e);
475        }
476    }
477
478    protected boolean transportFlush() throws IOException {
479        return true;
480    }
481
482    public void drainInbound() {
483        if (!getServiceState().isStarted() || readSource.isSuspended()) {
484            return;
485        }
486        try {
487            long initial = codec.getReadCounter();
488            // Only process upto 2 x the read buffer worth of data at a time so we can give
489            // other connections a chance to process their requests.
490            while( codec.getReadCounter()-initial < codec.getReadBufferSize()<<2 ) {
491                Object command = codec.read();
492                if ( command!=null ) {
493                    try {
494                        listener.onTransportCommand(command);
495                    } catch (Throwable e) {
496                        e.printStackTrace();
497                        onTransportFailure(new IOException("Transport listener failure."));
498                    }
499
500                    // the transport may be suspended after processing a command.
501                    if (getServiceState() == STOPPED || readSource.isSuspended()) {
502                        return;
503                    }
504                } else {
505                    return;
506                }
507            }
508            yieldSource.merge(1);
509        } catch (IOException e) {
510            onTransportFailure(e);
511        }
512    }
513
514    public SocketAddress getLocalAddress() {
515        return localAddress;
516    }
517
518    public SocketAddress getRemoteAddress() {
519        return remoteAddress;
520    }
521
522    private boolean assertConnected() {
523        try {
524            if ( !isConnected() ) {
525                throw new IOException("Not connected.");
526            }
527            return true;
528        } catch (IOException e) {
529            onTransportFailure(e);
530        }
531        return false;
532    }
533
534    public void suspendRead() {
535        if( isConnected() && readSource!=null ) {
536            readSource.suspend();
537        }
538    }
539
540
541    public void resumeRead() {
542        if( isConnected() && readSource!=null ) {
543            _resumeRead();
544        }
545    }
546
547    private void _resumeRead() {
548        readSource.resume();
549        dispatchQueue.execute(new Task(){
550            public void run() {
551                drainInbound();
552            }
553        });
554    }
555
556    protected void suspendWrite() {
557        if( isConnected() && writeSource!=null ) {
558            writeSource.suspend();
559        }
560    }
561
562    protected void resumeWrite() {
563        if( isConnected() && writeSource!=null ) {
564            writeSource.resume();
565        }
566    }
567
568    public TransportListener getTransportListener() {
569        return listener;
570    }
571
572    public void setTransportListener(TransportListener transportListener) {
573        this.listener = transportListener;
574    }
575
576    public ProtocolCodec getProtocolCodec() {
577        return codec;
578    }
579
580    public void setProtocolCodec(ProtocolCodec protocolCodec) throws Exception {
581        this.codec = protocolCodec;
582        if( channel!=null && codec!=null ) {
583            initializeCodec();
584        }
585    }
586
587    public boolean isConnected() {
588        return socketState.is(CONNECTED.class);
589    }
590
591    public boolean isClosed() {
592        return getServiceState() == STOPPED;
593    }
594
595    public boolean isUseLocalHost() {
596        return useLocalHost;
597    }
598
599    /**
600     * Sets whether 'localhost' or the actual local host name should be used to
601     * make local connections. On some operating systems such as Macs its not
602     * possible to connect as the local host name so localhost is better.
603     */
604    public void setUseLocalHost(boolean useLocalHost) {
605        this.useLocalHost = useLocalHost;
606    }
607
608    private void trace(String message) {
609        // TODO:
610    }
611
612    public DatagramChannel getDatagramChannel() {
613        return channel;
614    }
615
616    public ReadableByteChannel getReadChannel() {
617        return channel;
618    }
619
620    public WritableByteChannel getWriteChannel() {
621        return channel;
622    }
623
624    public int getTrafficClass() {
625        return trafficClass;
626    }
627
628    public void setTrafficClass(int trafficClass) {
629        this.trafficClass = trafficClass;
630    }
631
632    public int getReceiveBufferSize() {
633        return receiveBufferSize;
634    }
635
636    public void setReceiveBufferSize(int receiveBufferSize) {
637        this.receiveBufferSize = receiveBufferSize;
638    }
639
640    public int getSendBufferSize() {
641        return sendBufferSize;
642    }
643
644    public void setSendBufferSize(int sendBufferSize) {
645        this.sendBufferSize = sendBufferSize;
646    }
647
648    public Executor getBlockingExecutor() {
649        return blockingExecutor;
650    }
651
652    public void setBlockingExecutor(Executor blockingExecutor) {
653        this.blockingExecutor = blockingExecutor;
654    }
655
656}