001    /**
002     * Copyright (C) 2009-2013 Barchart, Inc. <http://www.barchart.com/>
003     *
004     * All rights reserved. Licensed under the OSI BSD License.
005     *
006     * http://www.opensource.org/licenses/bsd-license.php
007     */
008    package com.barchart.udt.nio;
009    
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ConnectionPendingException;
import java.nio.channels.IllegalBlockingModeException;
import java.nio.channels.SocketChannel;
import java.nio.channels.UnresolvedAddressException;
import java.nio.channels.UnsupportedAddressTypeException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.barchart.udt.ExceptionUDT;
import com.barchart.udt.SocketUDT;
import com.barchart.udt.TypeUDT;
import com.barchart.udt.anno.ThreadSafe;
027    
028    /**
029     * {@link SocketChannel}-like wrapper for {@link SocketUDT}, can be either
030     * stream or message oriented, depending on {@link TypeUDT}
031     * <p>
032     * The UDT socket that this SocketChannel wraps will be switched to blocking
033     * mode since this is the default for all SocketChannels on construction. If you
034     * require non-blocking functionality, you will need to call configureBlocking
035     * on the constructed SocketChannel class.
036     * <p>
037     * you must use {@link SelectorProviderUDT#openSocketChannel()} to obtain
038     * instance of this class; do not use JDK
039     * {@link java.nio.channels.SocketChannel#open()};
040     * <p>
041     * example:
042     * 
043     * <pre>
044     * SelectorProvider provider = SelectorProviderUDT.DATAGRAM;
045     * SocketChannel clientChannel = provider.openSocketChannel();
046     * clientChannel.configureBlocking(true);
047     * Socket clientSocket = clientChannel.socket();
048     * InetSocketAddress clientAddress = new InetSocketAddress(&quot;localhost&quot;, 10000);
049     * clientSocket.bind(clientAddress);
050     * assert clientSocket.isBound();
051     * InetSocketAddress serverAddress = new InetSocketAddress(&quot;localhost&quot;, 12345);
052     * clientChannel.connect(serverAddress);
053     * assert clientSocket.isConnected();
054     * </pre>
055     */
056    public class SocketChannelUDT extends SocketChannel implements ChannelUDT {
057    
058            protected static final Logger log = LoggerFactory
059                            .getLogger(SocketChannelUDT.class);
060    
061            protected final Object connectLock = new Object();
062    
063            /**
064             * local volatile variable, which mirrors super.blocking, to avoid the cost
065             * of synchronized call inside isBlocking()
066             */
067            protected volatile boolean isBlockingMode = isBlocking();
068    
069            protected volatile boolean isConnectFinished;
070    
071            protected volatile boolean isConnectionPending;
072    
073            @ThreadSafe("this")
074            protected NioSocketUDT socketAdapter;
075    
076            protected final SocketUDT socketUDT;
077    
078            protected SocketChannelUDT( //
079                            final SelectorProviderUDT provider, //
080                            final SocketUDT socketUDT //
081            ) throws ExceptionUDT {
082    
083                    super(provider);
084                    this.socketUDT = socketUDT;
085                    this.socketUDT.setBlocking(true);
086            }
087    
088            protected SocketChannelUDT( //
089                            final SelectorProviderUDT provider, //
090                            final SocketUDT socketUDT, //
091                            final boolean isConnected //
092            ) throws ExceptionUDT {
093    
094                    this(provider, socketUDT);
095    
096                    if (isConnected) {
097                            isConnectFinished = true;
098                            isConnectionPending = false;
099                    } else {
100                            isConnectFinished = false;
101                            isConnectionPending = true;
102                    }
103    
104            }
105    
106            @Override
107            public boolean connect(final SocketAddress remote) throws IOException {
108    
109                    if (!isOpen()) {
110                            throw new ClosedChannelException();
111                    }
112    
113                    if (isConnected()) {
114                            log.warn("already connected; ignoring remote={}", remote);
115                            return true;
116                    }
117    
118                    if (remote == null) {
119                            close();
120                            log.error("remote == null");
121                            throw new NullPointerException();
122                    }
123    
124                    final InetSocketAddress remoteSocket = (InetSocketAddress) remote;
125    
126                    if (remoteSocket.isUnresolved()) {
127                            log.error("can not use unresolved address: remote={}", remote);
128                            close();
129                            throw new UnresolvedAddressException();
130                    }
131    
132                    if (isBlocking()) {
133                            synchronized (connectLock) {
134                                    try {
135    
136                                            if (isConnectionPending) {
137                                                    close();
138                                                    throw new ConnectionPendingException();
139                                            }
140    
141                                            isConnectionPending = true;
142    
143                                            begin();
144    
145                                            socketUDT.connect(remoteSocket);
146    
147                                    } finally {
148    
149                                            end(true);
150    
151                                            isConnectionPending = false;
152    
153                                            connectLock.notifyAll();
154    
155                                    }
156                            }
157    
158                            return socketUDT.isConnected();
159    
160                    } else {
161    
162                            /** non Blocking */
163    
164                            if (!isRegistered()) {
165    
166                                    /** this channel is independent of any selector */
167    
168                                    log.error("UDT channel is in NON blocking mode; "
169                                                    + "must register with a selector " //
170                                                    + "before trying to connect(); " //
171                                                    + "socketId=" + socketUDT.id());
172    
173                                    throw new IllegalBlockingModeException();
174    
175                            }
176    
177                            /** this channel is registered with a selector */
178    
179                            synchronized (connectLock) {
180    
181                                    if (isConnectionPending) {
182                                            close();
183                                            log.error("connection already in progress");
184                                            throw new ConnectionPendingException();
185                                    }
186    
187                                    isConnectFinished = false;
188                                    isConnectionPending = true;
189    
190                                    socketUDT.connect(remoteSocket);
191    
192                            }
193    
194                            /**
195                             * connection operation must later be completed by invoking the
196                             * #finishConnect() method.
197                             */
198    
199                            return false;
200    
201                    }
202    
203            }
204    
205            @Override
206            public boolean finishConnect() throws IOException {
207    
208                    if (!isOpen()) {
209                            throw new ClosedChannelException();
210                    }
211    
212                    if (isBlocking()) {
213    
214                            synchronized (connectLock) {
215                                    while (isConnectionPending) {
216                                            try {
217                                                    connectLock.wait();
218                                            } catch (final InterruptedException e) {
219                                                    throw new IOException(e);
220                                            }
221                                    }
222                            }
223    
224                    }
225    
226                    if (isConnected()) {
227    
228                            isConnectFinished = true;
229                            isConnectionPending = false;
230    
231                            return true;
232    
233                    } else {
234    
235                            log.error("connect failure : {}", socketUDT);
236                            throw new IOException();
237    
238                    }
239    
240            }
241    
242            @Override
243            protected void implCloseSelectableChannel() throws IOException {
244                    socketUDT.close();
245            }
246    
247            @Override
248            protected void implConfigureBlocking(final boolean block)
249                            throws IOException {
250                    socketUDT.setBlocking(block);
251                    isBlockingMode = block;
252            }
253    
254            @Override
255            public boolean isConnected() {
256                    return socketUDT.isConnected();
257            }
258    
259            @Override
260            public boolean isConnectFinished() {
261                    return isConnectFinished;
262            }
263    
264            @Override
265            public boolean isConnectionPending() {
266                    return isConnectionPending;
267            }
268    
269            @Override
270            public KindUDT kindUDT() {
271                    return KindUDT.CONNECTOR;
272            }
273    
274            @Override
275            public SelectorProviderUDT providerUDT() {
276                    return (SelectorProviderUDT) super.provider();
277            }
278    
279            //
280    
281            /**
282             * See {@link java.nio.channels.SocketChannel#read(ByteBuffer)} contract;
283             * note: this method does not return (-1) as EOS (end of stream flag)
284             * 
285             * @return <code><0</code> should not happen<br>
286             *         <code>=0</code> blocking mode: timeout occurred on receive<br>
287             *         <code>=0</code> non-blocking mode: nothing is received by the
288             *         underlying UDT socket<br>
289             *         <code>>0</code> actual bytes received count<br>
290             * @see com.barchart.udt.SocketUDT#receive(ByteBuffer)
291             * @see com.barchart.udt.SocketUDT#receive(byte[], int, int)
292             */
293            @Override
294            public int read(final ByteBuffer buffer) throws IOException {
295    
296                    final int remaining = buffer.remaining();
297    
298                    if (remaining <= 0) {
299                            return 0;
300                    }
301    
302                    final SocketUDT socket = socketUDT;
303                    final boolean isBlocking = isBlockingMode;
304    
305                    final int sizeReceived;
306    
307                    try {
308    
309                            if (isBlocking) {
310                                    begin(); // JDK contract for NIO blocking calls
311                            }
312    
313                            if (buffer.isDirect()) {
314    
315                                    sizeReceived = socket.receive(buffer);
316    
317                            } else {
318    
319                                    final byte[] array = buffer.array();
320                                    final int position = buffer.position();
321                                    final int limit = buffer.limit();
322    
323                                    sizeReceived = socket.receive(array, position, limit);
324    
325                                    if (0 < sizeReceived && sizeReceived <= remaining) {
326                                            buffer.position(position + sizeReceived);
327                                    }
328    
329                            }
330    
331                    } finally {
332                            if (isBlocking) {
333                                    end(true); // JDK contract for NIO blocking calls
334                            }
335                    }
336    
337                    // see contract for receive()
338    
339                    if (sizeReceived < 0) {
340                            // log.trace("nothing was received; socket={}", socket);
341                            return 0;
342                    }
343    
344                    if (sizeReceived == 0) {
345                            // log.trace("receive timeout; socket={}", socket);
346                            return 0;
347                    }
348    
349                    if (sizeReceived <= remaining) {
350                            return sizeReceived;
351                    } else {
352                            log.error("should not happen: socket={}", socket);
353                            return 0;
354                    }
355    
356            }
357    
358            @Override
359            public long read(final ByteBuffer[] dsts, final int offset, final int length)
360                            throws IOException {
361                    throw new RuntimeException("feature not available");
362            }
363    
364            @Override
365            public synchronized NioSocketUDT socket() {
366                    if (socketAdapter == null) {
367                            try {
368                                    socketAdapter = new NioSocketUDT(this);
369                            } catch (final ExceptionUDT e) {
370                                    log.error("failed to make socket", e);
371                            }
372                    }
373                    return socketAdapter;
374            }
375    
376            @Override
377            public SocketUDT socketUDT() {
378                    return socketUDT;
379            }
380    
381            @Override
382            public String toString() {
383                    return socketUDT.toString();
384            }
385    
386            /**
387             * See {@link java.nio.channels.SocketChannel#write(ByteBuffer)} contract;
388             * 
389             * @return <code><0</code> should not happen<br>
390             *         <code>=0</code> blocking mode: timeout occurred on send<br>
391             *         <code>=0</code> non-blocking mode: buffer is full in the
392             *         underlying UDT socket; nothing is sent<br>
393             *         <code>>0</code> actual bytes sent count<br>
394             * @see com.barchart.udt.SocketUDT#send(ByteBuffer)
395             * @see com.barchart.udt.SocketUDT#send(byte[], int, int)
396             */
397            @Override
398            public int write(final ByteBuffer buffer) throws IOException {
399    
400                    // writeCount.incrementAndGet();
401    
402                    if (buffer == null) {
403                            throw new NullPointerException("buffer == null");
404                    }
405    
406                    final int remaining = buffer.remaining();
407    
408                    if (remaining <= 0) {
409                            return 0;
410                    }
411    
412                    final SocketUDT socket = socketUDT;
413                    final boolean isBlocking = isBlockingMode;
414    
415                    int sizeSent = 0;
416                    int ret = 0;
417    
418                    try {
419    
420                            if (isBlocking) {
421                                    begin(); // JDK contract for NIO blocking calls
422                            }
423    
424                            if (buffer.isDirect()) {
425    
426                                    do {
427                                            ret = socket.send(buffer);
428    
429                                            if (ret > 0)
430                                                    sizeSent += ret;
431    
432                                    } while (buffer.hasRemaining() && isBlocking);
433    
434                            } else {
435    
436                                    final byte[] array = buffer.array();
437                                    int position = buffer.position();
438                                    final int limit = buffer.limit();
439    
440                                    do {
441                                            ret = socket.send(array, position, limit);
442    
443                                            if (0 < ret && ret <= remaining) {
444                                                    sizeSent += ret;
445                                                    position += ret;
446                                                    buffer.position(position);
447                                            }
448    
449                                    } while (buffer.hasRemaining() && isBlocking);
450                            }
451                    } finally {
452                            if (isBlocking) {
453                                    end(true); // JDK contract for NIO blocking calls
454                            }
455                    }
456    
457                    // see contract for send()
458    
459                    if (ret < 0) {
460                            // log.trace("no buffer space; socket={}", socket);
461                            return 0;
462                    }
463    
464                    if (ret == 0) {
465                            // log.trace("send timeout; socket={}", socket);
466                            return 0;
467                    }
468    
469                    if (sizeSent <= remaining) {
470                            return sizeSent;
471                    } else {
472                            log.error("should not happen; socket={}", socket);
473                            return 0;
474                    }
475    
476            }
477    
478            @Override
479            public long write(final ByteBuffer[] bufferArray, final int offset,
480                            final int length) throws IOException {
481    
482                    try {
483    
484                            long total = 0;
485    
486                            for (int index = offset; index < offset + length; index++) {
487    
488                                    final ByteBuffer buffer = bufferArray[index];
489    
490                                    final int remaining = buffer.remaining();
491                                    final int processed = write(buffer);
492    
493                                    if (remaining == processed) {
494                                            total += processed;
495                                    } else {
496                                            throw new IllegalStateException(
497                                                            "failed to write buffer in array");
498                                    }
499    
500                            }
501    
502                            return total;
503    
504                    } catch (final Throwable e) {
505                            throw new IOException("failed to write buffer array", e);
506                    }
507    
508            }
509    
510            @Override
511            public TypeUDT typeUDT() {
512                    return providerUDT().type();
513            }
514    
515            /** java 7 */
516            public SocketChannelUDT bind(final SocketAddress localAddress)
517                            throws IOException {
518    
519                    socketUDT.bind((InetSocketAddress) localAddress);
520    
521                    return this;
522    
523            }
524    
525    }