Java: introducing websocket support.
This commit is contained in:
578
src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java
Normal file
578
src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java
Normal file
@@ -0,0 +1,578 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package nginx.unit.websocket;
|
||||
|
||||
import java.io.EOFException;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.AsynchronousSocketChannel;
|
||||
import java.nio.channels.CompletionHandler;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import javax.net.ssl.SSLEngineResult;
|
||||
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
|
||||
import javax.net.ssl.SSLEngineResult.Status;
|
||||
import javax.net.ssl.SSLException;
|
||||
|
||||
import org.apache.juli.logging.Log;
|
||||
import org.apache.juli.logging.LogFactory;
|
||||
import org.apache.tomcat.util.res.StringManager;
|
||||
|
||||
/**
|
||||
* Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot
|
||||
* more testing before it can be considered robust.
|
||||
*/
|
||||
public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {
|
||||
|
||||
private final Log log =
|
||||
LogFactory.getLog(AsyncChannelWrapperSecure.class);
|
||||
private static final StringManager sm =
|
||||
StringManager.getManager(AsyncChannelWrapperSecure.class);
|
||||
|
||||
private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921);
|
||||
private final AsynchronousSocketChannel socketChannel;
|
||||
private final SSLEngine sslEngine;
|
||||
private final ByteBuffer socketReadBuffer;
|
||||
private final ByteBuffer socketWriteBuffer;
|
||||
// One thread for read, one for write
|
||||
private final ExecutorService executor =
|
||||
Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
|
||||
private AtomicBoolean writing = new AtomicBoolean(false);
|
||||
private AtomicBoolean reading = new AtomicBoolean(false);
|
||||
|
||||
public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel,
|
||||
SSLEngine sslEngine) {
|
||||
this.socketChannel = socketChannel;
|
||||
this.sslEngine = sslEngine;
|
||||
|
||||
int socketBufferSize = sslEngine.getSession().getPacketBufferSize();
|
||||
socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize);
|
||||
socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Future<Integer> read(ByteBuffer dst) {
|
||||
WrapperFuture<Integer,Void> future = new WrapperFuture<>();
|
||||
|
||||
if (!reading.compareAndSet(false, true)) {
|
||||
throw new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.concurrentRead"));
|
||||
}
|
||||
|
||||
ReadTask readTask = new ReadTask(dst, future);
|
||||
|
||||
executor.execute(readTask);
|
||||
|
||||
return future;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <B,A extends B> void read(ByteBuffer dst, A attachment,
|
||||
CompletionHandler<Integer,B> handler) {
|
||||
|
||||
WrapperFuture<Integer,B> future =
|
||||
new WrapperFuture<>(handler, attachment);
|
||||
|
||||
if (!reading.compareAndSet(false, true)) {
|
||||
throw new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.concurrentRead"));
|
||||
}
|
||||
|
||||
ReadTask readTask = new ReadTask(dst, future);
|
||||
|
||||
executor.execute(readTask);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Future<Integer> write(ByteBuffer src) {
|
||||
|
||||
WrapperFuture<Long,Void> inner = new WrapperFuture<>();
|
||||
|
||||
if (!writing.compareAndSet(false, true)) {
|
||||
throw new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.concurrentWrite"));
|
||||
}
|
||||
|
||||
WriteTask writeTask =
|
||||
new WriteTask(new ByteBuffer[] {src}, 0, 1, inner);
|
||||
|
||||
executor.execute(writeTask);
|
||||
|
||||
Future<Integer> future = new LongToIntegerFuture(inner);
|
||||
return future;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
|
||||
long timeout, TimeUnit unit, A attachment,
|
||||
CompletionHandler<Long,B> handler) {
|
||||
|
||||
WrapperFuture<Long,B> future =
|
||||
new WrapperFuture<>(handler, attachment);
|
||||
|
||||
if (!writing.compareAndSet(false, true)) {
|
||||
throw new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.concurrentWrite"));
|
||||
}
|
||||
|
||||
WriteTask writeTask = new WriteTask(srcs, offset, length, future);
|
||||
|
||||
executor.execute(writeTask);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
socketChannel.close();
|
||||
} catch (IOException e) {
|
||||
log.info(sm.getString("asyncChannelWrapperSecure.closeFail"));
|
||||
}
|
||||
executor.shutdownNow();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Future<Void> handshake() throws SSLException {
|
||||
|
||||
WrapperFuture<Void,Void> wFuture = new WrapperFuture<>();
|
||||
|
||||
Thread t = new WebSocketSslHandshakeThread(wFuture);
|
||||
t.start();
|
||||
|
||||
return wFuture;
|
||||
}
|
||||
|
||||
|
||||
private class WriteTask implements Runnable {
|
||||
|
||||
private final ByteBuffer[] srcs;
|
||||
private final int offset;
|
||||
private final int length;
|
||||
private final WrapperFuture<Long,?> future;
|
||||
|
||||
public WriteTask(ByteBuffer[] srcs, int offset, int length,
|
||||
WrapperFuture<Long,?> future) {
|
||||
this.srcs = srcs;
|
||||
this.future = future;
|
||||
this.offset = offset;
|
||||
this.length = length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
long written = 0;
|
||||
|
||||
try {
|
||||
for (int i = offset; i < offset + length; i++) {
|
||||
ByteBuffer src = srcs[i];
|
||||
while (src.hasRemaining()) {
|
||||
socketWriteBuffer.clear();
|
||||
|
||||
// Encrypt the data
|
||||
SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer);
|
||||
written += r.bytesConsumed();
|
||||
Status s = r.getStatus();
|
||||
|
||||
if (s == Status.OK || s == Status.BUFFER_OVERFLOW) {
|
||||
// Need to write out the bytes and may need to read from
|
||||
// the source again to empty it
|
||||
} else {
|
||||
// Status.BUFFER_UNDERFLOW - only happens on unwrap
|
||||
// Status.CLOSED - unexpected
|
||||
throw new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.statusWrap"));
|
||||
}
|
||||
|
||||
// Check for tasks
|
||||
if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
|
||||
Runnable runnable = sslEngine.getDelegatedTask();
|
||||
while (runnable != null) {
|
||||
runnable.run();
|
||||
runnable = sslEngine.getDelegatedTask();
|
||||
}
|
||||
}
|
||||
|
||||
socketWriteBuffer.flip();
|
||||
|
||||
// Do the write
|
||||
int toWrite = r.bytesProduced();
|
||||
while (toWrite > 0) {
|
||||
Future<Integer> f =
|
||||
socketChannel.write(socketWriteBuffer);
|
||||
Integer socketWrite = f.get();
|
||||
toWrite -= socketWrite.intValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (writing.compareAndSet(true, false)) {
|
||||
future.complete(Long.valueOf(written));
|
||||
} else {
|
||||
future.fail(new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.wrongStateWrite")));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
writing.set(false);
|
||||
future.fail(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private class ReadTask implements Runnable {
|
||||
|
||||
private final ByteBuffer dest;
|
||||
private final WrapperFuture<Integer,?> future;
|
||||
|
||||
public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) {
|
||||
this.dest = dest;
|
||||
this.future = future;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
int read = 0;
|
||||
|
||||
boolean forceRead = false;
|
||||
|
||||
try {
|
||||
while (read == 0) {
|
||||
socketReadBuffer.compact();
|
||||
|
||||
if (forceRead) {
|
||||
forceRead = false;
|
||||
Future<Integer> f = socketChannel.read(socketReadBuffer);
|
||||
Integer socketRead = f.get();
|
||||
if (socketRead.intValue() == -1) {
|
||||
throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof"));
|
||||
}
|
||||
}
|
||||
|
||||
socketReadBuffer.flip();
|
||||
|
||||
if (socketReadBuffer.hasRemaining()) {
|
||||
// Decrypt the data in the buffer
|
||||
SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest);
|
||||
read += r.bytesProduced();
|
||||
Status s = r.getStatus();
|
||||
|
||||
if (s == Status.OK) {
|
||||
// Bytes available for reading and there may be
|
||||
// sufficient data in the socketReadBuffer to
|
||||
// support further reads without reading from the
|
||||
// socket
|
||||
} else if (s == Status.BUFFER_UNDERFLOW) {
|
||||
// There is partial data in the socketReadBuffer
|
||||
if (read == 0) {
|
||||
// Need more data before the partial data can be
|
||||
// processed and some output generated
|
||||
forceRead = true;
|
||||
}
|
||||
// else return the data we have and deal with the
|
||||
// partial data on the next read
|
||||
} else if (s == Status.BUFFER_OVERFLOW) {
|
||||
// Not enough space in the destination buffer to
|
||||
// store all of the data. We could use a bytes read
|
||||
// value of -bufferSizeRequired to signal the new
|
||||
// buffer size required but an explicit exception is
|
||||
// clearer.
|
||||
if (reading.compareAndSet(true, false)) {
|
||||
throw new ReadBufferOverflowException(sslEngine.
|
||||
getSession().getApplicationBufferSize());
|
||||
} else {
|
||||
future.fail(new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.wrongStateRead")));
|
||||
}
|
||||
} else {
|
||||
// Status.CLOSED - unexpected
|
||||
throw new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.statusUnwrap"));
|
||||
}
|
||||
|
||||
// Check for tasks
|
||||
if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
|
||||
Runnable runnable = sslEngine.getDelegatedTask();
|
||||
while (runnable != null) {
|
||||
runnable.run();
|
||||
runnable = sslEngine.getDelegatedTask();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
forceRead = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (reading.compareAndSet(true, false)) {
|
||||
future.complete(Integer.valueOf(read));
|
||||
} else {
|
||||
future.fail(new IllegalStateException(sm.getString(
|
||||
"asyncChannelWrapperSecure.wrongStateRead")));
|
||||
}
|
||||
} catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException |
|
||||
ExecutionException | InterruptedException e) {
|
||||
reading.set(false);
|
||||
future.fail(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private class WebSocketSslHandshakeThread extends Thread {
|
||||
|
||||
private final WrapperFuture<Void,Void> hFuture;
|
||||
|
||||
private HandshakeStatus handshakeStatus;
|
||||
private Status resultStatus;
|
||||
|
||||
public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) {
|
||||
this.hFuture = hFuture;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
sslEngine.beginHandshake();
|
||||
// So the first compact does the right thing
|
||||
socketReadBuffer.position(socketReadBuffer.limit());
|
||||
|
||||
handshakeStatus = sslEngine.getHandshakeStatus();
|
||||
resultStatus = Status.OK;
|
||||
|
||||
boolean handshaking = true;
|
||||
|
||||
while(handshaking) {
|
||||
switch (handshakeStatus) {
|
||||
case NEED_WRAP: {
|
||||
socketWriteBuffer.clear();
|
||||
SSLEngineResult r =
|
||||
sslEngine.wrap(DUMMY, socketWriteBuffer);
|
||||
checkResult(r, true);
|
||||
socketWriteBuffer.flip();
|
||||
Future<Integer> fWrite =
|
||||
socketChannel.write(socketWriteBuffer);
|
||||
fWrite.get();
|
||||
break;
|
||||
}
|
||||
case NEED_UNWRAP: {
|
||||
socketReadBuffer.compact();
|
||||
if (socketReadBuffer.position() == 0 ||
|
||||
resultStatus == Status.BUFFER_UNDERFLOW) {
|
||||
Future<Integer> fRead =
|
||||
socketChannel.read(socketReadBuffer);
|
||||
fRead.get();
|
||||
}
|
||||
socketReadBuffer.flip();
|
||||
SSLEngineResult r =
|
||||
sslEngine.unwrap(socketReadBuffer, DUMMY);
|
||||
checkResult(r, false);
|
||||
break;
|
||||
}
|
||||
case NEED_TASK: {
|
||||
Runnable r = null;
|
||||
while ((r = sslEngine.getDelegatedTask()) != null) {
|
||||
r.run();
|
||||
}
|
||||
handshakeStatus = sslEngine.getHandshakeStatus();
|
||||
break;
|
||||
}
|
||||
case FINISHED: {
|
||||
handshaking = false;
|
||||
break;
|
||||
}
|
||||
case NOT_HANDSHAKING: {
|
||||
throw new SSLException(
|
||||
sm.getString("asyncChannelWrapperSecure.notHandshaking"));
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
hFuture.fail(e);
|
||||
return;
|
||||
}
|
||||
|
||||
hFuture.complete(null);
|
||||
}
|
||||
|
||||
private void checkResult(SSLEngineResult result, boolean wrap)
|
||||
throws SSLException {
|
||||
|
||||
handshakeStatus = result.getHandshakeStatus();
|
||||
resultStatus = result.getStatus();
|
||||
|
||||
if (resultStatus != Status.OK &&
|
||||
(wrap || resultStatus != Status.BUFFER_UNDERFLOW)) {
|
||||
throw new SSLException(
|
||||
sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus));
|
||||
}
|
||||
if (wrap && result.bytesConsumed() != 0) {
|
||||
throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap"));
|
||||
}
|
||||
if (!wrap && result.bytesProduced() != 0) {
|
||||
throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class WrapperFuture<T,A> implements Future<T> {
|
||||
|
||||
private final CompletionHandler<T,A> handler;
|
||||
private final A attachment;
|
||||
|
||||
private volatile T result = null;
|
||||
private volatile Throwable throwable = null;
|
||||
private CountDownLatch completionLatch = new CountDownLatch(1);
|
||||
|
||||
public WrapperFuture() {
|
||||
this(null, null);
|
||||
}
|
||||
|
||||
public WrapperFuture(CompletionHandler<T,A> handler, A attachment) {
|
||||
this.handler = handler;
|
||||
this.attachment = attachment;
|
||||
}
|
||||
|
||||
public void complete(T result) {
|
||||
this.result = result;
|
||||
completionLatch.countDown();
|
||||
if (handler != null) {
|
||||
handler.completed(result, attachment);
|
||||
}
|
||||
}
|
||||
|
||||
public void fail(Throwable t) {
|
||||
throwable = t;
|
||||
completionLatch.countDown();
|
||||
if (handler != null) {
|
||||
handler.failed(throwable, attachment);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean cancel(boolean mayInterruptIfRunning) {
|
||||
// Could support cancellation by closing the connection
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean isCancelled() {
|
||||
// Could support cancellation by closing the connection
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final boolean isDone() {
|
||||
return completionLatch.getCount() > 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T get() throws InterruptedException, ExecutionException {
|
||||
completionLatch.await();
|
||||
if (throwable != null) {
|
||||
throw new ExecutionException(throwable);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T get(long timeout, TimeUnit unit)
|
||||
throws InterruptedException, ExecutionException,
|
||||
TimeoutException {
|
||||
boolean latchResult = completionLatch.await(timeout, unit);
|
||||
if (latchResult == false) {
|
||||
throw new TimeoutException();
|
||||
}
|
||||
if (throwable != null) {
|
||||
throw new ExecutionException(throwable);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class LongToIntegerFuture implements Future<Integer> {
|
||||
|
||||
private final Future<Long> wrapped;
|
||||
|
||||
public LongToIntegerFuture(Future<Long> wrapped) {
|
||||
this.wrapped = wrapped;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean cancel(boolean mayInterruptIfRunning) {
|
||||
return wrapped.cancel(mayInterruptIfRunning);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCancelled() {
|
||||
return wrapped.isCancelled();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDone() {
|
||||
return wrapped.isDone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer get() throws InterruptedException, ExecutionException {
|
||||
Long result = wrapped.get();
|
||||
if (result.longValue() > Integer.MAX_VALUE) {
|
||||
throw new ExecutionException(sm.getString(
|
||||
"asyncChannelWrapperSecure.tooBig", result), null);
|
||||
}
|
||||
return Integer.valueOf(result.intValue());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer get(long timeout, TimeUnit unit)
|
||||
throws InterruptedException, ExecutionException,
|
||||
TimeoutException {
|
||||
Long result = wrapped.get(timeout, unit);
|
||||
if (result.longValue() > Integer.MAX_VALUE) {
|
||||
throw new ExecutionException(sm.getString(
|
||||
"asyncChannelWrapperSecure.tooBig", result), null);
|
||||
}
|
||||
return Integer.valueOf(result.intValue());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static class SecureIOThreadFactory implements ThreadFactory {
|
||||
|
||||
private AtomicInteger count = new AtomicInteger(0);
|
||||
|
||||
@Override
|
||||
public Thread newThread(Runnable r) {
|
||||
Thread t = new Thread(r);
|
||||
t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet());
|
||||
// No need to set the context class loader. The threads will be
|
||||
// cleaned up when the connection is closed.
|
||||
t.setDaemon(true);
|
||||
return t;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user