Java: introducing websocket support.

This commit is contained in:
Max Romanov
2019-09-05 15:27:32 +03:00
parent 3e23afb0d2
commit 2b8cab1e24
113 changed files with 15422 additions and 70 deletions

View File

@@ -98,10 +98,14 @@ import javax.servlet.http.HttpSessionEvent;
import javax.servlet.http.HttpSessionIdListener;
import javax.servlet.http.HttpSessionListener;
import javax.websocket.server.ServerEndpoint;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import nginx.unit.websocket.WsSession;
import org.eclipse.jetty.http.MimeTypes;
import org.w3c.dom.Document;
@@ -421,6 +425,9 @@ public class Context implements ServletContext, InitParams
loader_ = new AppClassLoader(urls,
Context.class.getClassLoader().getParent());
Class wsSession_class = WsSession.class;
trace("wsSession.test: " + WsSession.wsSession_test());
ClassLoader old = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(loader_);
@@ -429,28 +436,30 @@ public class Context implements ServletContext, InitParams
addListener(listener_classname);
}
ScanResult scan_res = null;
ClassGraph classgraph = new ClassGraph()
//.verbose()
.overrideClassLoaders(loader_)
.ignoreParentClassLoaders()
.enableClassInfo()
.enableAnnotationInfo()
//.enableSystemPackages()
.whitelistModules("javax.*")
//.enableAllInfo()
;
String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim();
if (verbose.equals("true")) {
classgraph.verbose();
}
ScanResult scan_res = classgraph.scan();
javax.websocket.server.ServerEndpointConfig.Configurator.setDefault(new nginx.unit.websocket.server.DefaultServerEndpointConfigurator());
loadInitializer(new nginx.unit.websocket.server.WsSci(), scan_res);
if (!metadata_complete_) {
ClassGraph classgraph = new ClassGraph()
//.verbose()
.overrideClassLoaders(loader_)
.ignoreParentClassLoaders()
.enableClassInfo()
.enableAnnotationInfo()
//.enableSystemPackages()
.whitelistModules("javax.*")
//.enableAllInfo()
;
String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim();
if (verbose.equals("true")) {
classgraph.verbose();
}
scan_res = classgraph.scan();
loadInitializers(scan_res);
}
@@ -1471,54 +1480,61 @@ public class Context implements ServletContext, InitParams
ServiceLoader.load(ServletContainerInitializer.class, loader_);
for (ServletContainerInitializer sci : initializers) {
loadInitializer(sci, scan_res);
}
}
trace("loadInitializers: initializer: " + sci.getClass().getName());
private void loadInitializer(ServletContainerInitializer sci, ScanResult scan_res)
{
trace("loadInitializer: initializer: " + sci.getClass().getName());
HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class);
if (ann == null) {
trace("loadInitializers: no HandlesTypes annotation");
continue;
}
HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class);
if (ann == null) {
trace("loadInitializer: no HandlesTypes annotation");
return;
}
Class<?>[] classes = ann.value();
if (classes == null) {
trace("loadInitializers: no handles classes");
continue;
}
Class<?>[] classes = ann.value();
if (classes == null) {
trace("loadInitializer: no handles classes");
return;
}
Set<Class<?>> handles_classes = new HashSet<>();
Set<Class<?>> handles_classes = new HashSet<>();
for (Class<?> c : classes) {
trace("loadInitializers: find handles: " + c.getName());
for (Class<?> c : classes) {
trace("loadInitializer: find handles: " + c.getName());
ClassInfoList handles = c.isInterface()
ClassInfoList handles =
c.isAnnotation()
? scan_res.getClassesWithAnnotation(c.getName())
: c.isInterface()
? scan_res.getClassesImplementing(c.getName())
: scan_res.getSubclasses(c.getName());
for (ClassInfo ci : handles) {
if (ci.isInterface()
|| ci.isAnnotation()
|| ci.isAbstract())
{
continue;
}
trace("loadInitializers: handles class: " + ci.getName());
handles_classes.add(ci.loadClass());
for (ClassInfo ci : handles) {
if (ci.isInterface()
|| ci.isAnnotation()
|| ci.isAbstract())
{
return;
}
}
if (handles_classes.isEmpty()) {
trace("loadInitializers: no handles implementations");
continue;
trace("loadInitializer: handles class: " + ci.getName());
handles_classes.add(ci.loadClass());
}
}
try {
sci.onStartup(handles_classes, this);
metadata_complete_ = true;
} catch(Exception e) {
System.err.println("loadInitializers: exception caught: " + e.toString());
}
if (handles_classes.isEmpty()) {
trace("loadInitializer: no handles implementations");
return;
}
try {
sci.onStartup(handles_classes, this);
metadata_complete_ = true;
} catch(Exception e) {
System.err.println("loadInitializer: exception caught: " + e.toString());
}
}
@@ -1691,6 +1707,21 @@ public class Context implements ServletContext, InitParams
listener_classnames_.add(ci.getName());
}
ClassInfoList endpoints = scan_res.getClassesWithAnnotation(ServerEndpoint.class.getName());
for (ClassInfo ci : endpoints) {
if (ci.isInterface()
|| ci.isAnnotation()
|| ci.isAbstract())
{
trace("scanClasses: skip server end point: " + ci.getName());
continue;
}
trace("scanClasses: server end point: " + ci.getName());
}
}
public void stop() throws IOException

View File

@@ -16,6 +16,7 @@ import java.lang.StringBuffer;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
@@ -65,6 +66,9 @@ import org.eclipse.jetty.http.MultiPartFormInputStream;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.MimeTypes;
import nginx.unit.websocket.WsSession;
import nginx.unit.websocket.WsIOException;
public class Request implements HttpServletRequest, DynamicPathRequest
{
private final Context context;
@@ -114,6 +118,9 @@ public class Request implements HttpServletRequest, DynamicPathRequest
private boolean request_session_id_from_url = false;
private Session session = null;
private WsSession wsSession = null;
private boolean skip_close_ws = false;
private final ServletRequestAttributeListener attr_listener;
public static final String BARE = "nginx.unit.request.bare";
@@ -1203,11 +1210,30 @@ public class Request implements HttpServletRequest, DynamicPathRequest
public <T extends HttpUpgradeHandler> T upgrade(
Class<T> httpUpgradeHandlerClass) throws java.io.IOException, ServletException
{
log("upgrade: " + httpUpgradeHandlerClass.getName());
trace("upgrade: " + httpUpgradeHandlerClass.getName());
return null;
T handler;
try {
handler = httpUpgradeHandlerClass.getConstructor().newInstance();
} catch (Exception e) {
throw new ServletException(e);
}
upgrade(req_info_ptr);
return handler;
}
private static native void upgrade(long req_info_ptr);
public boolean isUpgrade()
{
return isUpgrade(req_info_ptr);
}
private static native boolean isUpgrade(long req_info_ptr);
@Override
public String changeSessionId()
{
@@ -1248,5 +1274,65 @@ public class Request implements HttpServletRequest, DynamicPathRequest
public static native void trace(long req_info_ptr, String msg, int msg_len);
private static native Response getResponse(long req_info_ptr);
public void setWsSession(WsSession s)
{
wsSession = s;
}
private void processWsFrame(ByteBuffer buf, byte opCode, boolean last)
throws IOException
{
trace("processWsFrame: " + opCode + ", [" + buf.position() + ", " + buf.limit() + "]");
try {
wsSession.processFrame(buf, opCode, last);
} catch (WsIOException e) {
wsSession.onClose(e.getCloseReason());
}
}
private void closeWsSession()
{
trace("closeWsSession");
skip_close_ws = true;
wsSession.onClose();
}
public void sendWsFrame(ByteBuffer payload, byte opCode, boolean last,
long timeoutExpiry) throws IOException
{
trace("sendWsFrame: " + opCode + ", [" + payload.position() +
", " + payload.limit() + "]");
if (payload.isDirect()) {
sendWsFrame(req_info_ptr, payload, payload.position(),
payload.limit() - payload.position(), opCode, last);
} else {
sendWsFrame(req_info_ptr, payload.array(), payload.position(),
payload.limit() - payload.position(), opCode, last);
}
}
private static native void sendWsFrame(long req_info_ptr,
ByteBuffer buf, int pos, int len, byte opCode, boolean last);
private static native void sendWsFrame(long req_info_ptr,
byte[] arr, int pos, int len, byte opCode, boolean last);
public void closeWs()
{
if (skip_close_ws) {
return;
}
trace("closeWs");
closeWs(req_info_ptr);
}
private static native void closeWs(long req_info_ptr);
}

View File

@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.channels.AsynchronousChannelGroup;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.util.threads.ThreadPoolExecutor;
/**
* This is a utility class that enables multiple {@link WsWebSocketContainer}
* instances to share a single {@link AsynchronousChannelGroup} while ensuring
* that the group is destroyed when no longer required.
*/
public class AsyncChannelGroupUtil {
private static final StringManager sm =
StringManager.getManager(AsyncChannelGroupUtil.class);
private static AsynchronousChannelGroup group = null;
private static int usageCount = 0;
private static final Object lock = new Object();
private AsyncChannelGroupUtil() {
// Hide the default constructor
}
public static AsynchronousChannelGroup register() {
synchronized (lock) {
if (usageCount == 0) {
group = createAsynchronousChannelGroup();
}
usageCount++;
return group;
}
}
public static void unregister() {
synchronized (lock) {
usageCount--;
if (usageCount == 0) {
group.shutdown();
group = null;
}
}
}
private static AsynchronousChannelGroup createAsynchronousChannelGroup() {
// Need to do this with the right thread context class loader else the
// first web app to call this will trigger a leak
ClassLoader original = Thread.currentThread().getContextClassLoader();
try {
Thread.currentThread().setContextClassLoader(
AsyncIOThreadFactory.class.getClassLoader());
// These are the same settings as the default
// AsynchronousChannelGroup
int initialSize = Runtime.getRuntime().availableProcessors();
ExecutorService executorService = new ThreadPoolExecutor(
0,
Integer.MAX_VALUE,
Long.MAX_VALUE, TimeUnit.MILLISECONDS,
new SynchronousQueue<Runnable>(),
new AsyncIOThreadFactory());
try {
return AsynchronousChannelGroup.withCachedThreadPool(
executorService, initialSize);
} catch (IOException e) {
// No good reason for this to happen.
throw new IllegalStateException(sm.getString("asyncChannelGroup.createFail"));
}
} finally {
Thread.currentThread().setContextClassLoader(original);
}
}
private static class AsyncIOThreadFactory implements ThreadFactory {
static {
// Load NewThreadPrivilegedAction since newThread() will not be able
// to if called from an InnocuousThread.
// See https://bz.apache.org/bugzilla/show_bug.cgi?id=57490
NewThreadPrivilegedAction.load();
}
@Override
public Thread newThread(final Runnable r) {
// Create the new Thread within a doPrivileged block to ensure that
// the thread inherits the current ProtectionDomain which is
// essential to be able to use this with a Java Applet. See
// https://bz.apache.org/bugzilla/show_bug.cgi?id=57091
return AccessController.doPrivileged(new NewThreadPrivilegedAction(r));
}
// Non-anonymous class so that AsyncIOThreadFactory can load it
// explicitly
private static class NewThreadPrivilegedAction implements PrivilegedAction<Thread> {
private static AtomicInteger count = new AtomicInteger(0);
private final Runnable r;
public NewThreadPrivilegedAction(Runnable r) {
this.r = r;
}
@Override
public Thread run() {
Thread t = new Thread(r);
t.setName("WebSocketClient-AsyncIO-" + count.incrementAndGet());
t.setContextClassLoader(this.getClass().getClassLoader());
t.setDaemon(true);
return t;
}
private static void load() {
// NO-OP. Just provides a hook to enable the class to be loaded
}
}
}
}

View File

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
/**
* This is a wrapper for a {@link java.nio.channels.AsynchronousSocketChannel}
* that limits the methods available thereby simplifying the process of
* implementing SSL/TLS support since there are fewer methods to intercept.
*/
public interface AsyncChannelWrapper {
Future<Integer> read(ByteBuffer dst);
<B,A extends B> void read(ByteBuffer dst, A attachment,
CompletionHandler<Integer,B> handler);
Future<Integer> write(ByteBuffer src);
<B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
long timeout, TimeUnit unit, A attachment,
CompletionHandler<Long,B> handler);
void close();
Future<Void> handshake() throws SSLException;
}

View File

@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
* Generally, just passes calls straight to the wrapped
* {@link AsynchronousSocketChannel}. In some cases exceptions may be swallowed
* to save them being swallowed by the calling code.
*/
public class AsyncChannelWrapperNonSecure implements AsyncChannelWrapper {
private static final Future<Void> NOOP_FUTURE = new NoOpFuture();
private final AsynchronousSocketChannel socketChannel;
public AsyncChannelWrapperNonSecure(
AsynchronousSocketChannel socketChannel) {
this.socketChannel = socketChannel;
}
@Override
public Future<Integer> read(ByteBuffer dst) {
return socketChannel.read(dst);
}
@Override
public <B,A extends B> void read(ByteBuffer dst, A attachment,
CompletionHandler<Integer,B> handler) {
socketChannel.read(dst, attachment, handler);
}
@Override
public Future<Integer> write(ByteBuffer src) {
return socketChannel.write(src);
}
@Override
public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
long timeout, TimeUnit unit, A attachment,
CompletionHandler<Long,B> handler) {
socketChannel.write(
srcs, offset, length, timeout, unit, attachment, handler);
}
@Override
public void close() {
try {
socketChannel.close();
} catch (IOException e) {
// Ignore
}
}
@Override
public Future<Void> handshake() {
return NOOP_FUTURE;
}
private static final class NoOpFuture implements Future<Void> {
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return false;
}
@Override
public boolean isCancelled() {
return false;
}
@Override
public boolean isDone() {
return true;
}
@Override
public Void get() throws InterruptedException, ExecutionException {
return null;
}
@Override
public Void get(long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException,
TimeoutException {
return null;
}
}
}

View File

@@ -0,0 +1,578 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.res.StringManager;
/**
* Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot
* more testing before it can be considered robust.
*/
public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {
private final Log log =
LogFactory.getLog(AsyncChannelWrapperSecure.class);
private static final StringManager sm =
StringManager.getManager(AsyncChannelWrapperSecure.class);
private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921);
private final AsynchronousSocketChannel socketChannel;
private final SSLEngine sslEngine;
private final ByteBuffer socketReadBuffer;
private final ByteBuffer socketWriteBuffer;
// One thread for read, one for write
private final ExecutorService executor =
Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
private AtomicBoolean writing = new AtomicBoolean(false);
private AtomicBoolean reading = new AtomicBoolean(false);
public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel,
SSLEngine sslEngine) {
this.socketChannel = socketChannel;
this.sslEngine = sslEngine;
int socketBufferSize = sslEngine.getSession().getPacketBufferSize();
socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize);
socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize);
}
@Override
public Future<Integer> read(ByteBuffer dst) {
WrapperFuture<Integer,Void> future = new WrapperFuture<>();
if (!reading.compareAndSet(false, true)) {
throw new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.concurrentRead"));
}
ReadTask readTask = new ReadTask(dst, future);
executor.execute(readTask);
return future;
}
@Override
public <B,A extends B> void read(ByteBuffer dst, A attachment,
CompletionHandler<Integer,B> handler) {
WrapperFuture<Integer,B> future =
new WrapperFuture<>(handler, attachment);
if (!reading.compareAndSet(false, true)) {
throw new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.concurrentRead"));
}
ReadTask readTask = new ReadTask(dst, future);
executor.execute(readTask);
}
@Override
public Future<Integer> write(ByteBuffer src) {
WrapperFuture<Long,Void> inner = new WrapperFuture<>();
if (!writing.compareAndSet(false, true)) {
throw new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.concurrentWrite"));
}
WriteTask writeTask =
new WriteTask(new ByteBuffer[] {src}, 0, 1, inner);
executor.execute(writeTask);
Future<Integer> future = new LongToIntegerFuture(inner);
return future;
}
@Override
public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
long timeout, TimeUnit unit, A attachment,
CompletionHandler<Long,B> handler) {
WrapperFuture<Long,B> future =
new WrapperFuture<>(handler, attachment);
if (!writing.compareAndSet(false, true)) {
throw new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.concurrentWrite"));
}
WriteTask writeTask = new WriteTask(srcs, offset, length, future);
executor.execute(writeTask);
}
@Override
public void close() {
try {
socketChannel.close();
} catch (IOException e) {
log.info(sm.getString("asyncChannelWrapperSecure.closeFail"));
}
executor.shutdownNow();
}
@Override
public Future<Void> handshake() throws SSLException {
WrapperFuture<Void,Void> wFuture = new WrapperFuture<>();
Thread t = new WebSocketSslHandshakeThread(wFuture);
t.start();
return wFuture;
}
private class WriteTask implements Runnable {
private final ByteBuffer[] srcs;
private final int offset;
private final int length;
private final WrapperFuture<Long,?> future;
public WriteTask(ByteBuffer[] srcs, int offset, int length,
WrapperFuture<Long,?> future) {
this.srcs = srcs;
this.future = future;
this.offset = offset;
this.length = length;
}
@Override
public void run() {
long written = 0;
try {
for (int i = offset; i < offset + length; i++) {
ByteBuffer src = srcs[i];
while (src.hasRemaining()) {
socketWriteBuffer.clear();
// Encrypt the data
SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer);
written += r.bytesConsumed();
Status s = r.getStatus();
if (s == Status.OK || s == Status.BUFFER_OVERFLOW) {
// Need to write out the bytes and may need to read from
// the source again to empty it
} else {
// Status.BUFFER_UNDERFLOW - only happens on unwrap
// Status.CLOSED - unexpected
throw new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.statusWrap"));
}
// Check for tasks
if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
Runnable runnable = sslEngine.getDelegatedTask();
while (runnable != null) {
runnable.run();
runnable = sslEngine.getDelegatedTask();
}
}
socketWriteBuffer.flip();
// Do the write
int toWrite = r.bytesProduced();
while (toWrite > 0) {
Future<Integer> f =
socketChannel.write(socketWriteBuffer);
Integer socketWrite = f.get();
toWrite -= socketWrite.intValue();
}
}
}
if (writing.compareAndSet(true, false)) {
future.complete(Long.valueOf(written));
} else {
future.fail(new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.wrongStateWrite")));
}
} catch (Exception e) {
writing.set(false);
future.fail(e);
}
}
}
private class ReadTask implements Runnable {
private final ByteBuffer dest;
private final WrapperFuture<Integer,?> future;
public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) {
this.dest = dest;
this.future = future;
}
@Override
public void run() {
int read = 0;
boolean forceRead = false;
try {
while (read == 0) {
socketReadBuffer.compact();
if (forceRead) {
forceRead = false;
Future<Integer> f = socketChannel.read(socketReadBuffer);
Integer socketRead = f.get();
if (socketRead.intValue() == -1) {
throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof"));
}
}
socketReadBuffer.flip();
if (socketReadBuffer.hasRemaining()) {
// Decrypt the data in the buffer
SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest);
read += r.bytesProduced();
Status s = r.getStatus();
if (s == Status.OK) {
// Bytes available for reading and there may be
// sufficient data in the socketReadBuffer to
// support further reads without reading from the
// socket
} else if (s == Status.BUFFER_UNDERFLOW) {
// There is partial data in the socketReadBuffer
if (read == 0) {
// Need more data before the partial data can be
// processed and some output generated
forceRead = true;
}
// else return the data we have and deal with the
// partial data on the next read
} else if (s == Status.BUFFER_OVERFLOW) {
// Not enough space in the destination buffer to
// store all of the data. We could use a bytes read
// value of -bufferSizeRequired to signal the new
// buffer size required but an explicit exception is
// clearer.
if (reading.compareAndSet(true, false)) {
throw new ReadBufferOverflowException(sslEngine.
getSession().getApplicationBufferSize());
} else {
future.fail(new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.wrongStateRead")));
}
} else {
// Status.CLOSED - unexpected
throw new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.statusUnwrap"));
}
// Check for tasks
if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
Runnable runnable = sslEngine.getDelegatedTask();
while (runnable != null) {
runnable.run();
runnable = sslEngine.getDelegatedTask();
}
}
} else {
forceRead = true;
}
}
if (reading.compareAndSet(true, false)) {
future.complete(Integer.valueOf(read));
} else {
future.fail(new IllegalStateException(sm.getString(
"asyncChannelWrapperSecure.wrongStateRead")));
}
} catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException |
ExecutionException | InterruptedException e) {
reading.set(false);
future.fail(e);
}
}
}
private class WebSocketSslHandshakeThread extends Thread {
private final WrapperFuture<Void,Void> hFuture;
private HandshakeStatus handshakeStatus;
private Status resultStatus;
public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) {
this.hFuture = hFuture;
}
@Override
public void run() {
try {
sslEngine.beginHandshake();
// So the first compact does the right thing
socketReadBuffer.position(socketReadBuffer.limit());
handshakeStatus = sslEngine.getHandshakeStatus();
resultStatus = Status.OK;
boolean handshaking = true;
while(handshaking) {
switch (handshakeStatus) {
case NEED_WRAP: {
socketWriteBuffer.clear();
SSLEngineResult r =
sslEngine.wrap(DUMMY, socketWriteBuffer);
checkResult(r, true);
socketWriteBuffer.flip();
Future<Integer> fWrite =
socketChannel.write(socketWriteBuffer);
fWrite.get();
break;
}
case NEED_UNWRAP: {
socketReadBuffer.compact();
if (socketReadBuffer.position() == 0 ||
resultStatus == Status.BUFFER_UNDERFLOW) {
Future<Integer> fRead =
socketChannel.read(socketReadBuffer);
fRead.get();
}
socketReadBuffer.flip();
SSLEngineResult r =
sslEngine.unwrap(socketReadBuffer, DUMMY);
checkResult(r, false);
break;
}
case NEED_TASK: {
Runnable r = null;
while ((r = sslEngine.getDelegatedTask()) != null) {
r.run();
}
handshakeStatus = sslEngine.getHandshakeStatus();
break;
}
case FINISHED: {
handshaking = false;
break;
}
case NOT_HANDSHAKING: {
throw new SSLException(
sm.getString("asyncChannelWrapperSecure.notHandshaking"));
}
}
}
} catch (Exception e) {
hFuture.fail(e);
return;
}
hFuture.complete(null);
}
private void checkResult(SSLEngineResult result, boolean wrap)
throws SSLException {
handshakeStatus = result.getHandshakeStatus();
resultStatus = result.getStatus();
if (resultStatus != Status.OK &&
(wrap || resultStatus != Status.BUFFER_UNDERFLOW)) {
throw new SSLException(
sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus));
}
if (wrap && result.bytesConsumed() != 0) {
throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap"));
}
if (!wrap && result.bytesProduced() != 0) {
throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap"));
}
}
}
private static class WrapperFuture<T,A> implements Future<T> {
private final CompletionHandler<T,A> handler;
private final A attachment;
private volatile T result = null;
private volatile Throwable throwable = null;
private CountDownLatch completionLatch = new CountDownLatch(1);
public WrapperFuture() {
this(null, null);
}
public WrapperFuture(CompletionHandler<T,A> handler, A attachment) {
this.handler = handler;
this.attachment = attachment;
}
public void complete(T result) {
this.result = result;
completionLatch.countDown();
if (handler != null) {
handler.completed(result, attachment);
}
}
public void fail(Throwable t) {
throwable = t;
completionLatch.countDown();
if (handler != null) {
handler.failed(throwable, attachment);
}
}
@Override
public final boolean cancel(boolean mayInterruptIfRunning) {
// Could support cancellation by closing the connection
return false;
}
@Override
public final boolean isCancelled() {
// Could support cancellation by closing the connection
return false;
}
@Override
public final boolean isDone() {
return completionLatch.getCount() > 0;
}
@Override
public T get() throws InterruptedException, ExecutionException {
completionLatch.await();
if (throwable != null) {
throw new ExecutionException(throwable);
}
return result;
}
@Override
public T get(long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException,
TimeoutException {
boolean latchResult = completionLatch.await(timeout, unit);
if (latchResult == false) {
throw new TimeoutException();
}
if (throwable != null) {
throw new ExecutionException(throwable);
}
return result;
}
}
private static final class LongToIntegerFuture implements Future<Integer> {
private final Future<Long> wrapped;
public LongToIntegerFuture(Future<Long> wrapped) {
this.wrapped = wrapped;
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return wrapped.cancel(mayInterruptIfRunning);
}
@Override
public boolean isCancelled() {
return wrapped.isCancelled();
}
@Override
public boolean isDone() {
return wrapped.isDone();
}
@Override
public Integer get() throws InterruptedException, ExecutionException {
Long result = wrapped.get();
if (result.longValue() > Integer.MAX_VALUE) {
throw new ExecutionException(sm.getString(
"asyncChannelWrapperSecure.tooBig", result), null);
}
return Integer.valueOf(result.intValue());
}
@Override
public Integer get(long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException,
TimeoutException {
Long result = wrapped.get(timeout, unit);
if (result.longValue() > Integer.MAX_VALUE) {
throw new ExecutionException(sm.getString(
"asyncChannelWrapperSecure.tooBig", result), null);
}
return Integer.valueOf(result.intValue());
}
}
private static class SecureIOThreadFactory implements ThreadFactory {
private AtomicInteger count = new AtomicInteger(0);
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet());
// No need to set the context class loader. The threads will be
// cleaned up when the connection is closed.
t.setDaemon(true);
return t;
}
}
}

View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
/**
* Exception thrown on authentication error connecting to a remote
* websocket endpoint.
*/
public class AuthenticationException extends Exception {
private static final long serialVersionUID = 5709887412240096441L;
/**
* Create authentication exception.
* @param message the error message
*/
public AuthenticationException(String message) {
super(message);
}
}

View File

@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Base class for the authentication methods used by the websocket client.
*/
public abstract class Authenticator {
private static final Pattern pattern = Pattern
.compile("(\\w+)\\s*=\\s*(\"([^\"]+)\"|([^,=\"]+))\\s*,?");
/**
* Generate the authentication header that will be sent to the server.
* @param requestUri The request URI
* @param WWWAuthenticate The server auth challenge
* @param UserProperties The user information
* @return The auth header
* @throws AuthenticationException When an error occurs
*/
public abstract String getAuthorization(String requestUri, String WWWAuthenticate,
Map<String, Object> UserProperties) throws AuthenticationException;
/**
* Get the authentication method.
* @return the auth scheme
*/
public abstract String getSchemeName();
/**
* Utility method to parse the authentication header.
* @param WWWAuthenticate The server auth challenge
* @return the parsed header
*/
public Map<String, String> parseWWWAuthenticateHeader(String WWWAuthenticate) {
Matcher m = pattern.matcher(WWWAuthenticate);
Map<String, String> challenge = new HashMap<>();
while (m.find()) {
String key = m.group(1);
String qtedValue = m.group(3);
String value = m.group(4);
challenge.put(key, qtedValue != null ? qtedValue : value);
}
return challenge;
}
}

View File

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.Iterator;
import java.util.ServiceLoader;
/**
* Utility method to return the appropriate authenticator according to
* the scheme that the server uses.
*/
public class AuthenticatorFactory {
/**
* Return a new authenticator instance.
* @param authScheme The scheme used
* @return the authenticator
*/
public static Authenticator getAuthenticator(String authScheme) {
Authenticator auth = null;
switch (authScheme.toLowerCase()) {
case BasicAuthenticator.schemeName:
auth = new BasicAuthenticator();
break;
case DigestAuthenticator.schemeName:
auth = new DigestAuthenticator();
break;
default:
auth = loadAuthenticators(authScheme);
break;
}
return auth;
}
private static Authenticator loadAuthenticators(String authScheme) {
ServiceLoader<Authenticator> serviceLoader = ServiceLoader.load(Authenticator.class);
Iterator<Authenticator> auths = serviceLoader.iterator();
while (auths.hasNext()) {
Authenticator auth = auths.next();
if (auth.getSchemeName().equalsIgnoreCase(authScheme))
return auth;
}
return null;
}
}

View File

@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
public interface BackgroundProcess {
void backgroundProcess();
void setProcessPeriod(int period);
int getProcessPeriod();
}

View File

@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.HashSet;
import java.util.Set;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.ExceptionUtils;
import org.apache.tomcat.util.res.StringManager;
/**
* Provides a background processing mechanism that triggers roughly once a
* second. The class maintains a thread that only runs when there is at least
* one instance of {@link BackgroundProcess} registered.
*/
public class BackgroundProcessManager {
private final Log log =
LogFactory.getLog(BackgroundProcessManager.class);
private static final StringManager sm =
StringManager.getManager(BackgroundProcessManager.class);
private static final BackgroundProcessManager instance;
static {
instance = new BackgroundProcessManager();
}
public static BackgroundProcessManager getInstance() {
return instance;
}
private final Set<BackgroundProcess> processes = new HashSet<>();
private final Object processesLock = new Object();
private WsBackgroundThread wsBackgroundThread = null;
private BackgroundProcessManager() {
// Hide default constructor
}
public void register(BackgroundProcess process) {
synchronized (processesLock) {
if (processes.size() == 0) {
wsBackgroundThread = new WsBackgroundThread(this);
wsBackgroundThread.setContextClassLoader(
this.getClass().getClassLoader());
wsBackgroundThread.setDaemon(true);
wsBackgroundThread.start();
}
processes.add(process);
}
}
public void unregister(BackgroundProcess process) {
synchronized (processesLock) {
processes.remove(process);
if (wsBackgroundThread != null && processes.size() == 0) {
wsBackgroundThread.halt();
wsBackgroundThread = null;
}
}
}
private void process() {
Set<BackgroundProcess> currentProcesses = new HashSet<>();
synchronized (processesLock) {
currentProcesses.addAll(processes);
}
for (BackgroundProcess process : currentProcesses) {
try {
process.backgroundProcess();
} catch (Throwable t) {
ExceptionUtils.handleThrowable(t);
log.error(sm.getString(
"backgroundProcessManager.processFailed"), t);
}
}
}
/*
* For unit testing.
*/
int getProcessCount() {
synchronized (processesLock) {
return processes.size();
}
}
void shutdown() {
synchronized (processesLock) {
processes.clear();
if (wsBackgroundThread != null) {
wsBackgroundThread.halt();
wsBackgroundThread = null;
}
}
}
private static class WsBackgroundThread extends Thread {
private final BackgroundProcessManager manager;
private volatile boolean running = true;
public WsBackgroundThread(BackgroundProcessManager manager) {
setName("WebSocket background processing");
this.manager = manager;
}
@Override
public void run() {
while (running) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
// Ignore
}
manager.process();
}
}
public void halt() {
setName("WebSocket background processing - stopping");
running = false;
}
}
}

View File

@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Map;
/**
* Authenticator supporting the BASIC auth method.
*/
public class BasicAuthenticator extends Authenticator {
public static final String schemeName = "basic";
public static final String charsetparam = "charset";
@Override
public String getAuthorization(String requestUri, String WWWAuthenticate,
Map<String, Object> userProperties) throws AuthenticationException {
String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME);
String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD);
if (userName == null || password == null) {
throw new AuthenticationException(
"Failed to perform Basic authentication due to missing user/password");
}
Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate);
String userPass = userName + ":" + password;
Charset charset;
if (wwwAuthenticate.get(charsetparam) != null
&& wwwAuthenticate.get(charsetparam).equalsIgnoreCase("UTF-8")) {
charset = StandardCharsets.UTF_8;
} else {
charset = StandardCharsets.ISO_8859_1;
}
String base64 = Base64.getEncoder().encodeToString(userPass.getBytes(charset));
return " Basic " + base64;
}
@Override
public String getSchemeName() {
return schemeName;
}
}

View File

@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.websocket.Extension;
/**
* Internal implementation constants.
*/
public class Constants {
// OP Codes
public static final byte OPCODE_CONTINUATION = 0x00;
public static final byte OPCODE_TEXT = 0x01;
public static final byte OPCODE_BINARY = 0x02;
public static final byte OPCODE_CLOSE = 0x08;
public static final byte OPCODE_PING = 0x09;
public static final byte OPCODE_PONG = 0x0A;
// Internal OP Codes
// RFC 6455 limits OP Codes to 4 bits so these should never clash
// Always set bit 4 so these will be treated as control codes
static final byte INTERNAL_OPCODE_FLUSH = 0x18;
// Buffers
static final int DEFAULT_BUFFER_SIZE = Integer.getInteger(
"nginx.unit.websocket.DEFAULT_BUFFER_SIZE", 8 * 1024)
.intValue();
// Client connection
/**
* Property name to set to configure the value that is passed to
* {@link javax.net.ssl.SSLEngine#setEnabledProtocols(String[])}. The value
* should be a comma separated string.
*/
public static final String SSL_PROTOCOLS_PROPERTY =
"nginx.unit.websocket.SSL_PROTOCOLS";
public static final String SSL_TRUSTSTORE_PROPERTY =
"nginx.unit.websocket.SSL_TRUSTSTORE";
public static final String SSL_TRUSTSTORE_PWD_PROPERTY =
"nginx.unit.websocket.SSL_TRUSTSTORE_PWD";
public static final String SSL_TRUSTSTORE_PWD_DEFAULT = "changeit";
/**
* Property name to set to configure used SSLContext. The value should be an
* instance of SSLContext. If this property is present, the SSL_TRUSTSTORE*
* properties are ignored.
*/
public static final String SSL_CONTEXT_PROPERTY =
"nginx.unit.websocket.SSL_CONTEXT";
/**
* Property name to set to configure the timeout (in milliseconds) when
* establishing a WebSocket connection to server. The default is
* {@link #IO_TIMEOUT_MS_DEFAULT}.
*/
public static final String IO_TIMEOUT_MS_PROPERTY =
"nginx.unit.websocket.IO_TIMEOUT_MS";
public static final long IO_TIMEOUT_MS_DEFAULT = 5000;
// RFC 2068 recommended a limit of 5
// Most browsers have a default limit of 20
public static final String MAX_REDIRECTIONS_PROPERTY =
"nginx.unit.websocket.MAX_REDIRECTIONS";
public static final int MAX_REDIRECTIONS_DEFAULT = 20;
// HTTP upgrade header names and values
public static final String HOST_HEADER_NAME = "Host";
public static final String UPGRADE_HEADER_NAME = "Upgrade";
public static final String UPGRADE_HEADER_VALUE = "websocket";
public static final String ORIGIN_HEADER_NAME = "Origin";
public static final String CONNECTION_HEADER_NAME = "Connection";
public static final String CONNECTION_HEADER_VALUE = "upgrade";
public static final String LOCATION_HEADER_NAME = "Location";
public static final String AUTHORIZATION_HEADER_NAME = "Authorization";
public static final String WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate";
public static final String WS_VERSION_HEADER_NAME = "Sec-WebSocket-Version";
public static final String WS_VERSION_HEADER_VALUE = "13";
public static final String WS_KEY_HEADER_NAME = "Sec-WebSocket-Key";
public static final String WS_PROTOCOL_HEADER_NAME = "Sec-WebSocket-Protocol";
public static final String WS_EXTENSIONS_HEADER_NAME = "Sec-WebSocket-Extensions";
/// HTTP redirection status codes
public static final int MULTIPLE_CHOICES = 300;
public static final int MOVED_PERMANENTLY = 301;
public static final int FOUND = 302;
public static final int SEE_OTHER = 303;
public static final int USE_PROXY = 305;
public static final int TEMPORARY_REDIRECT = 307;
// Configuration for Origin header in client
static final String DEFAULT_ORIGIN_HEADER_VALUE =
System.getProperty("nginx.unit.websocket.DEFAULT_ORIGIN_HEADER_VALUE");
// Configuration for blocking sends
public static final String BLOCKING_SEND_TIMEOUT_PROPERTY =
"nginx.unit.websocket.BLOCKING_SEND_TIMEOUT";
// Milliseconds so this is 20 seconds
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
// Configuration for background processing checks intervals
static final int DEFAULT_PROCESS_PERIOD = Integer.getInteger(
"nginx.unit.websocket.DEFAULT_PROCESS_PERIOD", 10)
.intValue();
public static final String WS_AUTHENTICATION_USER_NAME = "nginx.unit.websocket.WS_AUTHENTICATION_USER_NAME";
public static final String WS_AUTHENTICATION_PASSWORD = "nginx.unit.websocket.WS_AUTHENTICATION_PASSWORD";
/* Configuration for extensions
* Note: These options are primarily present to enable this implementation
* to pass compliance tests. They are expected to be removed once
* the WebSocket API includes a mechanism for adding custom extensions
* and disabling built-in extensions.
*/
static final boolean DISABLE_BUILTIN_EXTENSIONS =
Boolean.getBoolean("nginx.unit.websocket.DISABLE_BUILTIN_EXTENSIONS");
static final boolean ALLOW_UNSUPPORTED_EXTENSIONS =
Boolean.getBoolean("nginx.unit.websocket.ALLOW_UNSUPPORTED_EXTENSIONS");
// Configuration for stream behavior
static final boolean STREAMS_DROP_EMPTY_MESSAGES =
Boolean.getBoolean("nginx.unit.websocket.STREAMS_DROP_EMPTY_MESSAGES");
public static final boolean STRICT_SPEC_COMPLIANCE =
Boolean.getBoolean("nginx.unit.websocket.STRICT_SPEC_COMPLIANCE");
public static final List<Extension> INSTALLED_EXTENSIONS;
static {
if (DISABLE_BUILTIN_EXTENSIONS) {
INSTALLED_EXTENSIONS = Collections.unmodifiableList(new ArrayList<Extension>());
} else {
List<Extension> installed = new ArrayList<>(1);
installed.add(new WsExtension("permessage-deflate"));
INSTALLED_EXTENSIONS = Collections.unmodifiableList(installed);
}
}
private Constants() {
// Hide default constructor
}
}

View File

@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import javax.websocket.Decoder;
public class DecoderEntry {
private final Class<?> clazz;
private final Class<? extends Decoder> decoderClazz;
public DecoderEntry(Class<?> clazz,
Class<? extends Decoder> decoderClazz) {
this.clazz = clazz;
this.decoderClazz = decoderClazz;
}
public Class<?> getClazz() {
return clazz;
}
public Class<? extends Decoder> getDecoderClazz() {
return decoderClazz;
}
}

View File

@@ -0,0 +1,150 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Map;
import org.apache.tomcat.util.security.MD5Encoder;
/**
* Authenticator supporting the DIGEST auth method.
*/
public class DigestAuthenticator extends Authenticator {
public static final String schemeName = "digest";
private SecureRandom cnonceGenerator;
private int nonceCount = 0;
private long cNonce;
@Override
public String getAuthorization(String requestUri, String WWWAuthenticate,
Map<String, Object> userProperties) throws AuthenticationException {
String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME);
String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD);
if (userName == null || password == null) {
throw new AuthenticationException(
"Failed to perform Digest authentication due to missing user/password");
}
Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate);
String realm = wwwAuthenticate.get("realm");
String nonce = wwwAuthenticate.get("nonce");
String messageQop = wwwAuthenticate.get("qop");
String algorithm = wwwAuthenticate.get("algorithm") == null ? "MD5"
: wwwAuthenticate.get("algorithm");
String opaque = wwwAuthenticate.get("opaque");
StringBuilder challenge = new StringBuilder();
if (!messageQop.isEmpty()) {
if (cnonceGenerator == null) {
cnonceGenerator = new SecureRandom();
}
cNonce = cnonceGenerator.nextLong();
nonceCount++;
}
challenge.append("Digest ");
challenge.append("username =\"" + userName + "\",");
challenge.append("realm=\"" + realm + "\",");
challenge.append("nonce=\"" + nonce + "\",");
challenge.append("uri=\"" + requestUri + "\",");
try {
challenge.append("response=\"" + calculateRequestDigest(requestUri, userName, password,
realm, nonce, messageQop, algorithm) + "\",");
}
catch (NoSuchAlgorithmException e) {
throw new AuthenticationException(
"Unable to generate request digest " + e.getMessage());
}
challenge.append("algorithm=" + algorithm + ",");
challenge.append("opaque=\"" + opaque + "\",");
if (!messageQop.isEmpty()) {
challenge.append("qop=\"" + messageQop + "\"");
challenge.append(",cnonce=\"" + cNonce + "\",");
challenge.append("nc=" + String.format("%08X", Integer.valueOf(nonceCount)));
}
return challenge.toString();
}
private String calculateRequestDigest(String requestUri, String userName, String password,
String realm, String nonce, String qop, String algorithm)
throws NoSuchAlgorithmException {
StringBuilder preDigest = new StringBuilder();
String A1;
if (algorithm.equalsIgnoreCase("MD5"))
A1 = userName + ":" + realm + ":" + password;
else
A1 = encodeMD5(userName + ":" + realm + ":" + password) + ":" + nonce + ":" + cNonce;
/*
* If the "qop" value is "auth-int", then A2 is: A2 = Method ":"
* digest-uri-value ":" H(entity-body) since we do not have an entity-body, A2 =
* Method ":" digest-uri-value for auth and auth_int
*/
String A2 = "GET:" + requestUri;
preDigest.append(encodeMD5(A1));
preDigest.append(":");
preDigest.append(nonce);
if (qop.toLowerCase().contains("auth")) {
preDigest.append(":");
preDigest.append(String.format("%08X", Integer.valueOf(nonceCount)));
preDigest.append(":");
preDigest.append(String.valueOf(cNonce));
preDigest.append(":");
preDigest.append(qop);
}
preDigest.append(":");
preDigest.append(encodeMD5(A2));
return encodeMD5(preDigest.toString());
}
private String encodeMD5(String value) throws NoSuchAlgorithmException {
byte[] bytesOfMessage = value.getBytes(StandardCharsets.ISO_8859_1);
MessageDigest md = MessageDigest.getInstance("MD5");
byte[] thedigest = md.digest(bytesOfMessage);
return MD5Encoder.encode(thedigest);
}
@Override
public String getSchemeName() {
return schemeName;
}
}

View File

@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import org.apache.tomcat.util.res.StringManager;
/**
* Converts a Future to a SendHandler.
*/
class FutureToSendHandler implements Future<Void>, SendHandler {
private static final StringManager sm = StringManager.getManager(FutureToSendHandler.class);
private final CountDownLatch latch = new CountDownLatch(1);
private final WsSession wsSession;
private volatile AtomicReference<SendResult> result = new AtomicReference<>(null);
public FutureToSendHandler(WsSession wsSession) {
this.wsSession = wsSession;
}
// --------------------------------------------------------- SendHandler
@Override
public void onResult(SendResult result) {
this.result.compareAndSet(null, result);
latch.countDown();
}
// -------------------------------------------------------------- Future
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
// Cancelling the task is not supported
return false;
}
@Override
public boolean isCancelled() {
// Cancelling the task is not supported
return false;
}
@Override
public boolean isDone() {
return latch.getCount() == 0;
}
@Override
public Void get() throws InterruptedException,
ExecutionException {
try {
wsSession.registerFuture(this);
latch.await();
} finally {
wsSession.unregisterFuture(this);
}
if (result.get().getException() != null) {
throw new ExecutionException(result.get().getException());
}
return null;
}
@Override
public Void get(long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException,
TimeoutException {
boolean retval = false;
try {
wsSession.registerFuture(this);
retval = latch.await(timeout, unit);
} finally {
wsSession.unregisterFuture(this);
}
if (retval == false) {
throw new TimeoutException(sm.getString("futureToSendHandler.timeout",
Long.valueOf(timeout), unit.toString().toLowerCase()));
}
if (result.get().getException() != null) {
throw new ExecutionException(result.get().getException());
}
return null;
}
}

View File

@@ -0,0 +1,147 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
asyncChannelGroup.createFail=Unable to create dedicated AsynchronousChannelGroup for WebSocket clients which is required to prevent memory leaks in complex class loader environments like JavaEE containers
asyncChannelWrapperSecure.closeFail=Failed to close channel cleanly
asyncChannelWrapperSecure.check.notOk=TLS handshake returned an unexpected status [{0}]
asyncChannelWrapperSecure.check.unwrap=Bytes were written to the output during a read
asyncChannelWrapperSecure.check.wrap=Bytes were consumed from the input during a write
asyncChannelWrapperSecure.concurrentRead=Concurrent read operations are not permitted
asyncChannelWrapperSecure.concurrentWrite=Concurrent write operations are not permitted
asyncChannelWrapperSecure.eof=Unexpected end of stream
asyncChannelWrapperSecure.notHandshaking=Unexpected state [NOT_HANDSHAKING] during TLS handshake
asyncChannelWrapperSecure.readOverflow=Buffer overflow. [{0}] bytes to write into a [{1}] byte buffer that already contained [{2}] bytes.
asyncChannelWrapperSecure.statusUnwrap=Unexpected Status of SSLEngineResult after an unwrap() operation
asyncChannelWrapperSecure.statusWrap=Unexpected Status of SSLEngineResult after a wrap() operation
asyncChannelWrapperSecure.tooBig=The result [{0}] is too big to be expressed as an Integer
asyncChannelWrapperSecure.wrongStateRead=Flag that indicates a read is in progress was found to be false (it should have been true) when trying to complete a read operation
asyncChannelWrapperSecure.wrongStateWrite=Flag that indicates a write is in progress was found to be false (it should have been true) when trying to complete a write operation
backgroundProcessManager.processFailed=A background process failed
caseInsensitiveKeyMap.nullKey=Null keys are not permitted
futureToSendHandler.timeout=Operation timed out after waiting [{0}] [{1}] to complete
perMessageDeflate.deflateFailed=Failed to decompress a compressed WebSocket frame
perMessageDeflate.duplicateParameter=Duplicate definition of the [{0}] extension parameter
perMessageDeflate.invalidWindowSize=An invalid windows of [{1}] size was specified for [{0}]. Valid values are whole numbers from 8 to 15 inclusive.
perMessageDeflate.unknownParameter=An unknown extension parameter [{0}] was defined
transformerFactory.unsupportedExtension=The extension [{0}] is not supported
util.notToken=An illegal extension parameter was specified with name [{0}] and value [{1}]
util.invalidMessageHandler=The message handler provided does not have an onMessage(Object) method
util.invalidType=Unable to coerce value [{0}] to type [{1}]. That type is not supported.
util.unknownDecoderType=The Decoder type [{0}] is not recognized
# Note the wsFrame.* messages are used as close reasons in WebSocket control
# frames and therefore must be 123 bytes (not characters) or less in length.
# Messages are encoded using UTF-8 where a single character may be encoded in
# as many as 4 bytes.
wsFrame.alreadyResumed=Message receiving has already been resumed.
wsFrame.alreadySuspended=Message receiving has already been suspended.
wsFrame.bufferTooSmall=No async message support and buffer too small. Buffer size: [{0}], Message size: [{1}]
wsFrame.byteToLongFail=Too many bytes ([{0}]) were provided to be converted into a long
wsFrame.closed=New frame received after a close control frame
wsFrame.controlFragmented=A fragmented control frame was received but control frames may not be fragmented
wsFrame.controlPayloadTooBig=A control frame was sent with a payload of size [{0}] which is larger than the maximum permitted of 125 bytes
wsFrame.controlNoFin=A control frame was sent that did not have the fin bit set. Control frames are not permitted to use continuation frames.
wsFrame.illegalReadState=Unexpected read state [{0}]
wsFrame.invalidOpCode= A WebSocket frame was sent with an unrecognised opCode of [{0}]
wsFrame.invalidUtf8=A WebSocket text frame was received that could not be decoded to UTF-8 because it contained invalid byte sequences
wsFrame.invalidUtf8Close=A WebSocket close frame was received with a close reason that contained invalid UTF-8 byte sequences
wsFrame.ioeTriggeredClose=An unrecoverable IOException occurred so the connection was closed
wsFrame.messageTooBig=The message was [{0}] bytes long but the MessageHandler has a limit of [{1}] bytes
wsFrame.noContinuation=A new message was started when a continuation frame was expected
wsFrame.notMasked=The client frame was not masked but all client frames must be masked
wsFrame.oneByteCloseCode=The client sent a close frame with a single byte payload which is not valid
wsFrame.partialHeaderComplete=WebSocket frame received. fin [{0}], rsv [{1}], OpCode [{2}], payload length [{3}]
wsFrame.sessionClosed=The client data cannot be processed because the session has already been closed
wsFrame.suspendRequested=Suspend of the message receiving has already been requested.
wsFrame.textMessageTooBig=The decoded text message was too big for the output buffer and the endpoint does not support partial messages
wsFrame.wrongRsv=The client frame set the reserved bits to [{0}] for a message with opCode [{1}] which was not supported by this endpoint
wsFrameClient.ioe=Failure while reading data sent by server
wsHandshakeRequest.invalidUri=The string [{0}] cannot be used to construct a valid URI
wsHandshakeRequest.unknownScheme=The scheme [{0}] in the request is not recognised
wsRemoteEndpoint.acquireTimeout=The current message was not fully sent within the specified timeout
wsRemoteEndpoint.closed=Message will not be sent because the WebSocket session has been closed
wsRemoteEndpoint.closedDuringMessage=The remainder of the message will not be sent because the WebSocket session has been closed
wsRemoteEndpoint.closedOutputStream=This method may not be called as the OutputStream has been closed
wsRemoteEndpoint.closedWriter=This method may not be called as the Writer has been closed
wsRemoteEndpoint.changeType=When sending a fragmented message, all fragments must be of the same type
wsRemoteEndpoint.concurrentMessageSend=Messages may not be sent concurrently even when using the asynchronous send messages. The client must wait for the previous message to complete before sending the next.
wsRemoteEndpoint.flushOnCloseFailed=Batched messages still enabled after session has been closed. Unable to flush remaining batched message.
wsRemoteEndpoint.invalidEncoder=The specified encoder of type [{0}] could not be instantiated
wsRemoteEndpoint.noEncoder=No encoder specified for object of class [{0}]
wsRemoteEndpoint.nullData=Invalid null data argument
wsRemoteEndpoint.nullHandler=Invalid null handler argument
wsRemoteEndpoint.sendInterrupt=The current thread was interrupted while waiting for a blocking send to complete
wsRemoteEndpoint.tooMuchData=Ping or pong may not send more than 125 bytes
wsRemoteEndpoint.wrongState=The remote endpoint was in state [{0}] which is an invalid state for called method
# Note the following message is used as a close reason in a WebSocket control
# frame and therefore must be 123 bytes (not characters) or less in length.
# Messages are encoded using UTF-8 where a single character may be encoded in
# as many as 4 bytes.
wsSession.timeout=The WebSocket session [{0}] timeout expired
wsSession.closed=The WebSocket session [{0}] has been closed and no method (apart from close()) may be called on a closed session
wsSession.created=Created WebSocket session [{0}]
wsSession.doClose=Closing WebSocket session [{1}]
wsSession.duplicateHandlerBinary=A binary message handler has already been configured
wsSession.duplicateHandlerPong=A pong message handler has already been configured
wsSession.duplicateHandlerText=A text message handler has already been configured
wsSession.invalidHandlerTypePong=A pong message handler must implement MessageHandler.Whole
wsSession.flushFailOnClose=Failed to flush batched messages on session close
wsSession.messageFailed=Unable to write the complete message as the WebSocket connection has been closed
wsSession.sendCloseFail=Failed to send close message for session [{0}] to remote endpoint
wsSession.removeHandlerFailed=Unable to remove the handler [{0}] as it was not registered with this session
wsSession.unknownHandler=Unable to add the message handler [{0}] as it was for the unrecognised type [{1}]
wsSession.unknownHandlerType=Unable to add the message handler [{0}] as it was wrapped as the unrecognised type [{1}]
wsSession.instanceNew=Endpoint instance registration failed
wsSession.instanceDestroy=Endpoint instance unregistration failed
# Note the following message is used as a close reason in a WebSocket control
# frame and therefore must be 123 bytes (not characters) or less in length.
# Messages are encoded using UTF-8 where a single character may be encoded in
# as many as 4 bytes.
wsWebSocketContainer.shutdown=The web application is stopping
wsWebSocketContainer.defaultConfiguratorFail=Failed to create the default configurator
wsWebSocketContainer.endpointCreateFail=Failed to create a local endpoint of type [{0}]
wsWebSocketContainer.maxBuffer=This implementation limits the maximum size of a buffer to Integer.MAX_VALUE
wsWebSocketContainer.missingAnnotation=Cannot use POJO class [{0}] as it is not annotated with @ClientEndpoint
wsWebSocketContainer.sessionCloseFail=Session with ID [{0}] did not close cleanly
wsWebSocketContainer.asynchronousSocketChannelFail=Unable to open a connection to the server
wsWebSocketContainer.httpRequestFailed=The HTTP request to initiate the WebSocket connection failed
wsWebSocketContainer.invalidExtensionParameters=The server responded with extension parameters the client is unable to support
wsWebSocketContainer.invalidHeader=Unable to parse HTTP header as no colon is present to delimit header name and header value in [{0}]. The header has been skipped.
wsWebSocketContainer.invalidStatus=The HTTP response from the server [{0}] did not permit the HTTP upgrade to WebSocket
wsWebSocketContainer.invalidSubProtocol=The WebSocket server returned multiple values for the Sec-WebSocket-Protocol header
wsWebSocketContainer.pathNoHost=No host was specified in URI
wsWebSocketContainer.pathWrongScheme=The scheme [{0}] is not supported. The supported schemes are ws and wss
wsWebSocketContainer.proxyConnectFail=Failed to connect to the configured Proxy [{0}]. The HTTP response code was [{1}]
wsWebSocketContainer.sslEngineFail=Unable to create SSLEngine to support SSL/TLS connections
wsWebSocketContainer.missingLocationHeader=Failed to handle HTTP response code [{0}]. Missing Location header in response
wsWebSocketContainer.redirectThreshold=Cyclic Location header [{0}] detected / reached max number of redirects [{1}] of max [{2}]
wsWebSocketContainer.unsupportedAuthScheme=Failed to handle HTTP response code [{0}]. Unsupported Authentication scheme [{1}] returned in response
wsWebSocketContainer.failedAuthentication=Failed to handle HTTP response code [{0}]. Authentication header was not accepted by server.
wsWebSocketContainer.missingWWWAuthenticateHeader=Failed to handle HTTP response code [{0}]. Missing WWW-Authenticate header in response

View File

@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import javax.websocket.MessageHandler;
public class MessageHandlerResult {
private final MessageHandler handler;
private final MessageHandlerResultType type;
public MessageHandlerResult(MessageHandler handler,
MessageHandlerResultType type) {
this.handler = handler;
this.type = type;
}
public MessageHandler getHandler() {
return handler;
}
public MessageHandlerResultType getType() {
return type;
}
}

View File

@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
public enum MessageHandlerResultType {
BINARY,
TEXT,
PONG
}

View File

@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.nio.ByteBuffer;
import javax.websocket.SendHandler;
class MessagePart {
private final boolean fin;
private final int rsv;
private final byte opCode;
private final ByteBuffer payload;
private final SendHandler intermediateHandler;
private volatile SendHandler endHandler;
private final long blockingWriteTimeoutExpiry;
public MessagePart( boolean fin, int rsv, byte opCode, ByteBuffer payload,
SendHandler intermediateHandler, SendHandler endHandler,
long blockingWriteTimeoutExpiry) {
this.fin = fin;
this.rsv = rsv;
this.opCode = opCode;
this.payload = payload;
this.intermediateHandler = intermediateHandler;
this.endHandler = endHandler;
this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry;
}
public boolean isFin() {
return fin;
}
public int getRsv() {
return rsv;
}
public byte getOpCode() {
return opCode;
}
public ByteBuffer getPayload() {
return payload;
}
public SendHandler getIntermediateHandler() {
return intermediateHandler;
}
public SendHandler getEndHandler() {
return endHandler;
}
public void setEndHandler(SendHandler endHandler) {
this.endHandler = endHandler;
}
public long getBlockingWriteTimeoutExpiry() {
return blockingWriteTimeoutExpiry;
}
}

View File

@@ -0,0 +1,476 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import javax.websocket.Extension;
import javax.websocket.Extension.Parameter;
import javax.websocket.SendHandler;
import org.apache.tomcat.util.res.StringManager;
public class PerMessageDeflate implements Transformation {
private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class);
private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
private static final int RSV_BITMASK = 0b100;
private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1};
public static final String NAME = "permessage-deflate";
private final boolean serverContextTakeover;
private final int serverMaxWindowBits;
private final boolean clientContextTakeover;
private final int clientMaxWindowBits;
private final boolean isServer;
private final Inflater inflater = new Inflater(true);
private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1];
private volatile Transformation next;
private volatile boolean skipDecompression = false;
private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
private volatile boolean firstCompressedFrameWritten = false;
// Flag to track if a message is completely empty
private volatile boolean emptyMessage = true;
static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) {
// Accept the first preference that the endpoint is able to support
for (List<Parameter> preference : preferences) {
boolean ok = true;
boolean serverContextTakeover = true;
int serverMaxWindowBits = -1;
boolean clientContextTakeover = true;
int clientMaxWindowBits = -1;
for (Parameter param : preference) {
if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
if (serverContextTakeover) {
serverContextTakeover = false;
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
SERVER_NO_CONTEXT_TAKEOVER ));
}
} else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
if (clientContextTakeover) {
clientContextTakeover = false;
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
CLIENT_NO_CONTEXT_TAKEOVER ));
}
} else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) {
if (serverMaxWindowBits == -1) {
serverMaxWindowBits = Integer.parseInt(param.getValue());
if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) {
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.invalidWindowSize",
SERVER_MAX_WINDOW_BITS,
Integer.valueOf(serverMaxWindowBits)));
}
// Java SE API (as of Java 8) does not expose the API to
// control the Window size. It is effectively hard-coded
// to 15
if (isServer && serverMaxWindowBits != 15) {
ok = false;
break;
// Note server window size is not an issue for the
// client since the client will assume 15 and if the
// server uses a smaller window everything will
// still work
}
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
SERVER_MAX_WINDOW_BITS ));
}
} else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) {
if (clientMaxWindowBits == -1) {
if (param.getValue() == null) {
// Hint to server that the client supports this
// option. Java SE API (as of Java 8) does not
// expose the API to control the Window size. It is
// effectively hard-coded to 15
clientMaxWindowBits = 15;
} else {
clientMaxWindowBits = Integer.parseInt(param.getValue());
if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) {
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.invalidWindowSize",
CLIENT_MAX_WINDOW_BITS,
Integer.valueOf(clientMaxWindowBits)));
}
}
// Java SE API (as of Java 8) does not expose the API to
// control the Window size. It is effectively hard-coded
// to 15
if (!isServer && clientMaxWindowBits != 15) {
ok = false;
break;
// Note client window size is not an issue for the
// server since the server will assume 15 and if the
// client uses a smaller window everything will
// still work
}
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
CLIENT_MAX_WINDOW_BITS ));
}
} else {
// Unknown parameter
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.unknownParameter", param.getName()));
}
}
if (ok) {
return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits,
clientContextTakeover, clientMaxWindowBits, isServer);
}
}
// Failed to negotiate agreeable terms
return null;
}
private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits,
boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) {
this.serverContextTakeover = serverContextTakeover;
this.serverMaxWindowBits = serverMaxWindowBits;
this.clientContextTakeover = clientContextTakeover;
this.clientMaxWindowBits = clientMaxWindowBits;
this.isServer = isServer;
}
@Override
public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)
throws IOException {
// Control frames are never compressed and may appear in the middle of
// a WebSocket method. Pass them straight through.
if (Util.isControl(opCode)) {
return next.getMoreData(opCode, fin, rsv, dest);
}
if (!Util.isContinuation(opCode)) {
// First frame in new message
skipDecompression = (rsv & RSV_BITMASK) == 0;
}
// Pass uncompressed frames straight through.
if (skipDecompression) {
return next.getMoreData(opCode, fin, rsv, dest);
}
int written;
boolean usedEomBytes = false;
while (dest.remaining() > 0) {
// Space available in destination. Try and fill it.
try {
written = inflater.inflate(
dest.array(), dest.arrayOffset() + dest.position(), dest.remaining());
} catch (DataFormatException e) {
throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e);
}
dest.position(dest.position() + written);
if (inflater.needsInput() && !usedEomBytes ) {
if (dest.hasRemaining()) {
readBuffer.clear();
TransformationResult nextResult =
next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer);
inflater.setInput(
readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position());
if (TransformationResult.UNDERFLOW.equals(nextResult)) {
return nextResult;
} else if (TransformationResult.END_OF_FRAME.equals(nextResult) &&
readBuffer.position() == 0) {
if (fin) {
inflater.setInput(EOM_BYTES);
usedEomBytes = true;
} else {
return TransformationResult.END_OF_FRAME;
}
}
}
} else if (written == 0) {
if (fin && (isServer && !clientContextTakeover ||
!isServer && !serverContextTakeover)) {
inflater.reset();
}
return TransformationResult.END_OF_FRAME;
}
}
return TransformationResult.OVERFLOW;
}
@Override
public boolean validateRsv(int rsv, byte opCode) {
if (Util.isControl(opCode)) {
if ((rsv & RSV_BITMASK) != 0) {
return false;
} else {
if (next == null) {
return true;
} else {
return next.validateRsv(rsv, opCode);
}
}
} else {
int rsvNext = rsv;
if ((rsv & RSV_BITMASK) != 0) {
rsvNext = rsv ^ RSV_BITMASK;
}
if (next == null) {
return true;
} else {
return next.validateRsv(rsvNext, opCode);
}
}
}
@Override
public Extension getExtensionResponse() {
Extension result = new WsExtension(NAME);
List<Extension.Parameter> params = result.getParameters();
if (!serverContextTakeover) {
params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null));
}
if (serverMaxWindowBits != -1) {
params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS,
Integer.toString(serverMaxWindowBits)));
}
if (!clientContextTakeover) {
params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null));
}
if (clientMaxWindowBits != -1) {
params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS,
Integer.toString(clientMaxWindowBits)));
}
return result;
}
@Override
public void setNext(Transformation t) {
if (next == null) {
this.next = t;
} else {
next.setNext(t);
}
}
@Override
public boolean validateRsvBits(int i) {
if ((i & RSV_BITMASK) != 0) {
return false;
}
if (next == null) {
return true;
} else {
return next.validateRsvBits(i | RSV_BITMASK);
}
}
@Override
public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) {
List<MessagePart> allCompressedParts = new ArrayList<>();
for (MessagePart uncompressedPart : uncompressedParts) {
byte opCode = uncompressedPart.getOpCode();
boolean emptyPart = uncompressedPart.getPayload().limit() == 0;
emptyMessage = emptyMessage && emptyPart;
if (Util.isControl(opCode)) {
// Control messages can appear in the middle of other messages
// and must not be compressed. Pass it straight through
allCompressedParts.add(uncompressedPart);
} else if (emptyMessage && uncompressedPart.isFin()) {
// Zero length messages can't be compressed so pass the
// final (empty) part straight through.
allCompressedParts.add(uncompressedPart);
} else {
List<MessagePart> compressedParts = new ArrayList<>();
ByteBuffer uncompressedPayload = uncompressedPart.getPayload();
SendHandler uncompressedIntermediateHandler =
uncompressedPart.getIntermediateHandler();
deflater.setInput(uncompressedPayload.array(),
uncompressedPayload.arrayOffset() + uncompressedPayload.position(),
uncompressedPayload.remaining());
int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
boolean deflateRequired = true;
while (deflateRequired) {
ByteBuffer compressedPayload = writeBuffer;
int written = deflater.deflate(compressedPayload.array(),
compressedPayload.arrayOffset() + compressedPayload.position(),
compressedPayload.remaining(), flush);
compressedPayload.position(compressedPayload.position() + written);
if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) {
// This message part has been fully processed by the
// deflater. Fire the send handler for this message part
// and move on to the next message part.
break;
}
// If this point is reached, a new compressed message part
// will be created...
MessagePart compressedPart;
// .. and a new writeBuffer will be required.
writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
// Flip the compressed payload ready for writing
compressedPayload.flip();
boolean fin = uncompressedPart.isFin();
boolean full = compressedPayload.limit() == compressedPayload.capacity();
boolean needsInput = deflater.needsInput();
long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry();
if (fin && !full && needsInput) {
// End of compressed message. Drop EOM bytes and output.
compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length);
compressedPart = new MessagePart(true, getRsv(uncompressedPart),
opCode, compressedPayload, uncompressedIntermediateHandler,
uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
deflateRequired = false;
startNewMessage();
} else if (full && !needsInput) {
// Write buffer full and input message not fully read.
// Output and start new compressed part.
compressedPart = new MessagePart(false, getRsv(uncompressedPart),
opCode, compressedPayload, uncompressedIntermediateHandler,
uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
} else if (!fin && full && needsInput) {
// Write buffer full and input message not fully read.
// Output and get more data.
compressedPart = new MessagePart(false, getRsv(uncompressedPart),
opCode, compressedPayload, uncompressedIntermediateHandler,
uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
deflateRequired = false;
} else if (fin && full && needsInput) {
// Write buffer full. Input fully read. Deflater may be
// in one of four states:
// - output complete (just happened to align with end of
// buffer
// - in middle of EOM bytes
// - about to write EOM bytes
// - more data to write
int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH);
if (eomBufferWritten < EOM_BUFFER.length) {
// EOM has just been completed
compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten);
compressedPart = new MessagePart(true,
getRsv(uncompressedPart), opCode, compressedPayload,
uncompressedIntermediateHandler, uncompressedIntermediateHandler,
blockingWriteTimeoutExpiry);
deflateRequired = false;
startNewMessage();
} else {
// More data to write
// Copy bytes to new write buffer
writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten);
compressedPart = new MessagePart(false,
getRsv(uncompressedPart), opCode, compressedPayload,
uncompressedIntermediateHandler, uncompressedIntermediateHandler,
blockingWriteTimeoutExpiry);
}
} else {
throw new IllegalStateException("Should never happen");
}
// Add the newly created compressed part to the set of parts
// to pass on to the next transformation.
compressedParts.add(compressedPart);
}
SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler();
int size = compressedParts.size();
if (size > 0) {
compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler);
}
allCompressedParts.addAll(compressedParts);
}
}
if (next == null) {
return allCompressedParts;
} else {
return next.sendMessagePart(allCompressedParts);
}
}
private void startNewMessage() {
firstCompressedFrameWritten = false;
emptyMessage = true;
if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) {
deflater.reset();
}
}
private int getRsv(MessagePart uncompressedMessagePart) {
int result = uncompressedMessagePart.getRsv();
if (!firstCompressedFrameWritten) {
result += RSV_BITMASK;
firstCompressedFrameWritten = true;
}
return result;
}
@Override
public void close() {
// There will always be a next transformation
next.close();
inflater.end();
deflater.end();
}
}

View File

@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
public class ReadBufferOverflowException extends IOException {
private static final long serialVersionUID = 1L;
private final int minBufferSize;
public ReadBufferOverflowException(int minBufferSize) {
this.minBufferSize = minBufferSize;
}
public int getMinBufferSize() {
return minBufferSize;
}
}

View File

@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import javax.websocket.Extension;
/**
* The internal representation of the transformation that a WebSocket extension
* performs on a message.
*/
public interface Transformation {
/**
* Sets the next transformation in the pipeline.
* @param t The next transformation
*/
void setNext(Transformation t);
/**
* Validate that the RSV bit(s) required by this transformation are not
* being used by another extension. The implementation is expected to set
* any bits it requires before passing the set of in-use bits to the next
* transformation.
*
* @param i The RSV bits marked as in use so far as an int in the
* range zero to seven with RSV1 as the MSB and RSV3 as the
* LSB
*
* @return <code>true</code> if the combination of RSV bits used by the
* transformations in the pipeline do not conflict otherwise
* <code>false</code>
*/
boolean validateRsvBits(int i);
/**
* Obtain the extension that describes the information to be returned to the
* client.
*
* @return The extension information that describes the parameters that have
* been agreed for this transformation
*/
Extension getExtensionResponse();
/**
* Obtain more input data.
*
* @param opCode The opcode for the frame currently being processed
* @param fin Is this the final frame in this WebSocket message?
* @param rsv The reserved bits for the frame currently being
* processed
* @param dest The buffer in which the data is to be written
*
* @return The result of trying to read more data from the transform
*
* @throws IOException If an I/O error occurs while reading data from the
* transform
*/
TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) throws IOException;
/**
* Validates the RSV and opcode combination (assumed to have been extracted
* from a WebSocket Frame) for this extension. The implementation is
* expected to unset any RSV bits it has validated before passing the
* remaining RSV bits to the next transformation in the pipeline.
*
* @param rsv The RSV bits received as an int in the range zero to
* seven with RSV1 as the MSB and RSV3 as the LSB
* @param opCode The opCode received
*
* @return <code>true</code> if the RSV is valid otherwise
* <code>false</code>
*/
boolean validateRsv(int rsv, byte opCode);
/**
* Takes the provided list of messages, transforms them, passes the
* transformed list on to the next transformation (if any) and then returns
* the resulting list of message parts after all of the transformations have
* been applied.
*
* @param messageParts The list of messages to be transformed
*
* @return The list of messages after this any any subsequent
* transformations have been applied. The size of the returned list
* may be bigger or smaller than the size of the input list
*/
List<MessagePart> sendMessagePart(List<MessagePart> messageParts);
/**
* Clean-up any resources that were used by the transformation.
*/
void close();
}

View File

@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.List;
import javax.websocket.Extension;
import org.apache.tomcat.util.res.StringManager;
public class TransformationFactory {
private static final StringManager sm = StringManager.getManager(TransformationFactory.class);
private static final TransformationFactory factory = new TransformationFactory();
private TransformationFactory() {
// Hide default constructor
}
public static TransformationFactory getInstance() {
return factory;
}
public Transformation create(String name, List<List<Extension.Parameter>> preferences,
boolean isServer) {
if (PerMessageDeflate.NAME.equals(name)) {
return PerMessageDeflate.negotiate(preferences, isServer);
}
if (Constants.ALLOW_UNSUPPORTED_EXTENSIONS) {
return null;
} else {
throw new IllegalArgumentException(
sm.getString("transformerFactory.unsupportedExtension", name));
}
}
}

View File

@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
public enum TransformationResult {
/**
* The end of the available data was reached before the WebSocket frame was
* completely read.
*/
UNDERFLOW,
/**
* The provided destination buffer was filled before all of the available
* data from the WebSocket frame could be processed.
*/
OVERFLOW,
/**
* The end of the WebSocket frame was reached and all the data from that
* frame processed into the provided destination buffer.
*/
END_OF_FRAME
}

View File

@@ -0,0 +1,666 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.InputStream;
import java.io.Reader;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.websocket.CloseReason.CloseCode;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.Decoder;
import javax.websocket.Decoder.Binary;
import javax.websocket.Decoder.BinaryStream;
import javax.websocket.Decoder.Text;
import javax.websocket.Decoder.TextStream;
import javax.websocket.DeploymentException;
import javax.websocket.Encoder;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.MessageHandler;
import javax.websocket.PongMessage;
import javax.websocket.Session;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.pojo.PojoMessageHandlerPartialBinary;
import nginx.unit.websocket.pojo.PojoMessageHandlerWholeBinary;
import nginx.unit.websocket.pojo.PojoMessageHandlerWholeText;
/**
* Utility class for internal use only within the
* {@link nginx.unit.websocket} package.
*/
public class Util {
private static final StringManager sm = StringManager.getManager(Util.class);
private static final Queue<SecureRandom> randoms =
new ConcurrentLinkedQueue<>();
private Util() {
// Hide default constructor
}
static boolean isControl(byte opCode) {
return (opCode & 0x08) != 0;
}
static boolean isText(byte opCode) {
return opCode == Constants.OPCODE_TEXT;
}
static boolean isContinuation(byte opCode) {
return opCode == Constants.OPCODE_CONTINUATION;
}
static CloseCode getCloseCode(int code) {
if (code > 2999 && code < 5000) {
return CloseCodes.getCloseCode(code);
}
switch (code) {
case 1000:
return CloseCodes.NORMAL_CLOSURE;
case 1001:
return CloseCodes.GOING_AWAY;
case 1002:
return CloseCodes.PROTOCOL_ERROR;
case 1003:
return CloseCodes.CANNOT_ACCEPT;
case 1004:
// Should not be used in a close frame
// return CloseCodes.RESERVED;
return CloseCodes.PROTOCOL_ERROR;
case 1005:
// Should not be used in a close frame
// return CloseCodes.NO_STATUS_CODE;
return CloseCodes.PROTOCOL_ERROR;
case 1006:
// Should not be used in a close frame
// return CloseCodes.CLOSED_ABNORMALLY;
return CloseCodes.PROTOCOL_ERROR;
case 1007:
return CloseCodes.NOT_CONSISTENT;
case 1008:
return CloseCodes.VIOLATED_POLICY;
case 1009:
return CloseCodes.TOO_BIG;
case 1010:
return CloseCodes.NO_EXTENSION;
case 1011:
return CloseCodes.UNEXPECTED_CONDITION;
case 1012:
// Not in RFC6455
// return CloseCodes.SERVICE_RESTART;
return CloseCodes.PROTOCOL_ERROR;
case 1013:
// Not in RFC6455
// return CloseCodes.TRY_AGAIN_LATER;
return CloseCodes.PROTOCOL_ERROR;
case 1015:
// Should not be used in a close frame
// return CloseCodes.TLS_HANDSHAKE_FAILURE;
return CloseCodes.PROTOCOL_ERROR;
default:
return CloseCodes.PROTOCOL_ERROR;
}
}
static byte[] generateMask() {
// SecureRandom is not thread-safe so need to make sure only one thread
// uses it at a time. In theory, the pool could grow to the same size
// as the number of request processing threads. In reality it will be
// a lot smaller.
// Get a SecureRandom from the pool
SecureRandom sr = randoms.poll();
// If one isn't available, generate a new one
if (sr == null) {
try {
sr = SecureRandom.getInstance("SHA1PRNG");
} catch (NoSuchAlgorithmException e) {
// Fall back to platform default
sr = new SecureRandom();
}
}
// Generate the mask
byte[] result = new byte[4];
sr.nextBytes(result);
// Put the SecureRandom back in the poll
randoms.add(sr);
return result;
}
static Class<?> getMessageType(MessageHandler listener) {
return Util.getGenericType(MessageHandler.class,
listener.getClass()).getClazz();
}
private static Class<?> getDecoderType(Class<? extends Decoder> decoder) {
return Util.getGenericType(Decoder.class, decoder).getClazz();
}
static Class<?> getEncoderType(Class<? extends Encoder> encoder) {
return Util.getGenericType(Encoder.class, encoder).getClazz();
}
private static <T> TypeResult getGenericType(Class<T> type,
Class<? extends T> clazz) {
// Look to see if this class implements the interface of interest
// Get all the interfaces
Type[] interfaces = clazz.getGenericInterfaces();
for (Type iface : interfaces) {
// Only need to check interfaces that use generics
if (iface instanceof ParameterizedType) {
ParameterizedType pi = (ParameterizedType) iface;
// Look for the interface of interest
if (pi.getRawType() instanceof Class) {
if (type.isAssignableFrom((Class<?>) pi.getRawType())) {
return getTypeParameter(
clazz, pi.getActualTypeArguments()[0]);
}
}
}
}
// Interface not found on this class. Look at the superclass.
@SuppressWarnings("unchecked")
Class<? extends T> superClazz =
(Class<? extends T>) clazz.getSuperclass();
if (superClazz == null) {
// Finished looking up the class hierarchy without finding anything
return null;
}
TypeResult superClassTypeResult = getGenericType(type, superClazz);
int dimension = superClassTypeResult.getDimension();
if (superClassTypeResult.getIndex() == -1 && dimension == 0) {
// Superclass implements interface and defines explicit type for
// the interface of interest
return superClassTypeResult;
}
if (superClassTypeResult.getIndex() > -1) {
// Superclass implements interface and defines unknown type for
// the interface of interest
// Map that unknown type to the generic types defined in this class
ParameterizedType superClassType =
(ParameterizedType) clazz.getGenericSuperclass();
TypeResult result = getTypeParameter(clazz,
superClassType.getActualTypeArguments()[
superClassTypeResult.getIndex()]);
result.incrementDimension(superClassTypeResult.getDimension());
if (result.getClazz() != null && result.getDimension() > 0) {
superClassTypeResult = result;
} else {
return result;
}
}
if (superClassTypeResult.getDimension() > 0) {
StringBuilder className = new StringBuilder();
for (int i = 0; i < dimension; i++) {
className.append('[');
}
className.append('L');
className.append(superClassTypeResult.getClazz().getCanonicalName());
className.append(';');
Class<?> arrayClazz;
try {
arrayClazz = Class.forName(className.toString());
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException(e);
}
return new TypeResult(arrayClazz, -1, 0);
}
// Error will be logged further up the call stack
return null;
}
/*
* For a generic parameter, return either the Class used or if the type
* is unknown, the index for the type in definition of the class
*/
private static TypeResult getTypeParameter(Class<?> clazz, Type argType) {
if (argType instanceof Class<?>) {
return new TypeResult((Class<?>) argType, -1, 0);
} else if (argType instanceof ParameterizedType) {
return new TypeResult((Class<?>)((ParameterizedType) argType).getRawType(), -1, 0);
} else if (argType instanceof GenericArrayType) {
Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType();
TypeResult result = getTypeParameter(clazz, arrayElementType);
result.incrementDimension(1);
return result;
} else {
TypeVariable<?>[] tvs = clazz.getTypeParameters();
for (int i = 0; i < tvs.length; i++) {
if (tvs[i].equals(argType)) {
return new TypeResult(null, i, 0);
}
}
return null;
}
}
public static boolean isPrimitive(Class<?> clazz) {
if (clazz.isPrimitive()) {
return true;
} else if(clazz.equals(Boolean.class) ||
clazz.equals(Byte.class) ||
clazz.equals(Character.class) ||
clazz.equals(Double.class) ||
clazz.equals(Float.class) ||
clazz.equals(Integer.class) ||
clazz.equals(Long.class) ||
clazz.equals(Short.class)) {
return true;
}
return false;
}
public static Object coerceToType(Class<?> type, String value) {
if (type.equals(String.class)) {
return value;
} else if (type.equals(boolean.class) || type.equals(Boolean.class)) {
return Boolean.valueOf(value);
} else if (type.equals(byte.class) || type.equals(Byte.class)) {
return Byte.valueOf(value);
} else if (type.equals(char.class) || type.equals(Character.class)) {
return Character.valueOf(value.charAt(0));
} else if (type.equals(double.class) || type.equals(Double.class)) {
return Double.valueOf(value);
} else if (type.equals(float.class) || type.equals(Float.class)) {
return Float.valueOf(value);
} else if (type.equals(int.class) || type.equals(Integer.class)) {
return Integer.valueOf(value);
} else if (type.equals(long.class) || type.equals(Long.class)) {
return Long.valueOf(value);
} else if (type.equals(short.class) || type.equals(Short.class)) {
return Short.valueOf(value);
} else {
throw new IllegalArgumentException(sm.getString(
"util.invalidType", value, type.getName()));
}
}
public static List<DecoderEntry> getDecoders(
List<Class<? extends Decoder>> decoderClazzes)
throws DeploymentException{
List<DecoderEntry> result = new ArrayList<>();
if (decoderClazzes != null) {
for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
// Need to instantiate decoder to ensure it is valid and that
// deployment can be failed if it is not
@SuppressWarnings("unused")
Decoder instance;
try {
instance = decoderClazz.getConstructor().newInstance();
} catch (ReflectiveOperationException e) {
throw new DeploymentException(
sm.getString("pojoMethodMapping.invalidDecoder",
decoderClazz.getName()), e);
}
DecoderEntry entry = new DecoderEntry(
Util.getDecoderType(decoderClazz), decoderClazz);
result.add(entry);
}
}
return result;
}
static Set<MessageHandlerResult> getMessageHandlers(Class<?> target,
MessageHandler listener, EndpointConfig endpointConfig,
Session session) {
// Will never be more than 2 types
Set<MessageHandlerResult> results = new HashSet<>(2);
// Simple cases - handlers already accepts one of the types expected by
// the frame handling code
if (String.class.isAssignableFrom(target)) {
MessageHandlerResult result =
new MessageHandlerResult(listener,
MessageHandlerResultType.TEXT);
results.add(result);
} else if (ByteBuffer.class.isAssignableFrom(target)) {
MessageHandlerResult result =
new MessageHandlerResult(listener,
MessageHandlerResultType.BINARY);
results.add(result);
} else if (PongMessage.class.isAssignableFrom(target)) {
MessageHandlerResult result =
new MessageHandlerResult(listener,
MessageHandlerResultType.PONG);
results.add(result);
// Handler needs wrapping and optional decoder to convert it to one of
// the types expected by the frame handling code
} else if (byte[].class.isAssignableFrom(target)) {
boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass());
MessageHandlerResult result = new MessageHandlerResult(
whole ? new PojoMessageHandlerWholeBinary(listener,
getOnMessageMethod(listener), session,
endpointConfig, matchDecoders(target, endpointConfig, true),
new Object[1], 0, true, -1, false, -1) :
new PojoMessageHandlerPartialBinary(listener,
getOnMessagePartialMethod(listener), session,
new Object[2], 0, true, 1, -1, -1),
MessageHandlerResultType.BINARY);
results.add(result);
} else if (InputStream.class.isAssignableFrom(target)) {
MessageHandlerResult result = new MessageHandlerResult(
new PojoMessageHandlerWholeBinary(listener,
getOnMessageMethod(listener), session,
endpointConfig, matchDecoders(target, endpointConfig, true),
new Object[1], 0, true, -1, true, -1),
MessageHandlerResultType.BINARY);
results.add(result);
} else if (Reader.class.isAssignableFrom(target)) {
MessageHandlerResult result = new MessageHandlerResult(
new PojoMessageHandlerWholeText(listener,
getOnMessageMethod(listener), session,
endpointConfig, matchDecoders(target, endpointConfig, false),
new Object[1], 0, true, -1, -1),
MessageHandlerResultType.TEXT);
results.add(result);
} else {
// Handler needs wrapping and requires decoder to convert it to one
// of the types expected by the frame handling code
DecoderMatch decoderMatch = matchDecoders(target, endpointConfig);
Method m = getOnMessageMethod(listener);
if (decoderMatch.getBinaryDecoders().size() > 0) {
MessageHandlerResult result = new MessageHandlerResult(
new PojoMessageHandlerWholeBinary(listener, m, session,
endpointConfig,
decoderMatch.getBinaryDecoders(), new Object[1],
0, false, -1, false, -1),
MessageHandlerResultType.BINARY);
results.add(result);
}
if (decoderMatch.getTextDecoders().size() > 0) {
MessageHandlerResult result = new MessageHandlerResult(
new PojoMessageHandlerWholeText(listener, m, session,
endpointConfig,
decoderMatch.getTextDecoders(), new Object[1],
0, false, -1, -1),
MessageHandlerResultType.TEXT);
results.add(result);
}
}
if (results.size() == 0) {
throw new IllegalArgumentException(
sm.getString("wsSession.unknownHandler", listener, target));
}
return results;
}
private static List<Class<? extends Decoder>> matchDecoders(Class<?> target,
EndpointConfig endpointConfig, boolean binary) {
DecoderMatch decoderMatch = matchDecoders(target, endpointConfig);
if (binary) {
if (decoderMatch.getBinaryDecoders().size() > 0) {
return decoderMatch.getBinaryDecoders();
}
} else if (decoderMatch.getTextDecoders().size() > 0) {
return decoderMatch.getTextDecoders();
}
return null;
}
private static DecoderMatch matchDecoders(Class<?> target,
EndpointConfig endpointConfig) {
DecoderMatch decoderMatch;
try {
List<Class<? extends Decoder>> decoders =
endpointConfig.getDecoders();
List<DecoderEntry> decoderEntries = getDecoders(decoders);
decoderMatch = new DecoderMatch(target, decoderEntries);
} catch (DeploymentException e) {
throw new IllegalArgumentException(e);
}
return decoderMatch;
}
public static void parseExtensionHeader(List<Extension> extensions,
String header) {
// The relevant ABNF for the Sec-WebSocket-Extensions is as follows:
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ; When using the quoted-string syntax variant, the value
// ; after quoted-string unescaping MUST conform to the
// ; 'token' ABNF.
//
// The limiting of parameter values to tokens or "quoted tokens" makes
// the parsing of the header significantly simpler and allows a number
// of short-cuts to be taken.
// Step one, split the header into individual extensions using ',' as a
// separator
String unparsedExtensions[] = header.split(",");
for (String unparsedExtension : unparsedExtensions) {
// Step two, split the extension into the registered name and
// parameter/value pairs using ';' as a separator
String unparsedParameters[] = unparsedExtension.split(";");
WsExtension extension = new WsExtension(unparsedParameters[0].trim());
for (int i = 1; i < unparsedParameters.length; i++) {
int equalsPos = unparsedParameters[i].indexOf('=');
String name;
String value;
if (equalsPos == -1) {
name = unparsedParameters[i].trim();
value = null;
} else {
name = unparsedParameters[i].substring(0, equalsPos).trim();
value = unparsedParameters[i].substring(equalsPos + 1).trim();
int len = value.length();
if (len > 1) {
if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') {
value = value.substring(1, value.length() - 1);
}
}
}
// Make sure value doesn't contain any of the delimiters since
// that would indicate something went wrong
if (containsDelims(name) || containsDelims(value)) {
throw new IllegalArgumentException(sm.getString(
"util.notToken", name, value));
}
if (value != null &&
(value.indexOf(',') > -1 || value.indexOf(';') > -1 ||
value.indexOf('\"') > -1 || value.indexOf('=') > -1)) {
throw new IllegalArgumentException(sm.getString("", value));
}
extension.addParameter(new WsExtensionParameter(name, value));
}
extensions.add(extension);
}
}
private static boolean containsDelims(String input) {
if (input == null || input.length() == 0) {
return false;
}
for (char c : input.toCharArray()) {
switch (c) {
case ',':
case ';':
case '\"':
case '=':
return true;
default:
// NO_OP
}
}
return false;
}
private static Method getOnMessageMethod(MessageHandler listener) {
try {
return listener.getClass().getMethod("onMessage", Object.class);
} catch (NoSuchMethodException | SecurityException e) {
throw new IllegalArgumentException(
sm.getString("util.invalidMessageHandler"), e);
}
}
private static Method getOnMessagePartialMethod(MessageHandler listener) {
try {
return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE);
} catch (NoSuchMethodException | SecurityException e) {
throw new IllegalArgumentException(
sm.getString("util.invalidMessageHandler"), e);
}
}
public static class DecoderMatch {
private final List<Class<? extends Decoder>> textDecoders =
new ArrayList<>();
private final List<Class<? extends Decoder>> binaryDecoders =
new ArrayList<>();
private final Class<?> target;
public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) {
this.target = target;
for (DecoderEntry decoderEntry : decoderEntries) {
if (decoderEntry.getClazz().isAssignableFrom(target)) {
if (Binary.class.isAssignableFrom(
decoderEntry.getDecoderClazz())) {
binaryDecoders.add(decoderEntry.getDecoderClazz());
// willDecode() method means this decoder may or may not
// decode a message so need to carry on checking for
// other matches
} else if (BinaryStream.class.isAssignableFrom(
decoderEntry.getDecoderClazz())) {
binaryDecoders.add(decoderEntry.getDecoderClazz());
// Stream decoders have to process the message so no
// more decoders can be matched
break;
} else if (Text.class.isAssignableFrom(
decoderEntry.getDecoderClazz())) {
textDecoders.add(decoderEntry.getDecoderClazz());
// willDecode() method means this decoder may or may not
// decode a message so need to carry on checking for
// other matches
} else if (TextStream.class.isAssignableFrom(
decoderEntry.getDecoderClazz())) {
textDecoders.add(decoderEntry.getDecoderClazz());
// Stream decoders have to process the message so no
// more decoders can be matched
break;
} else {
throw new IllegalArgumentException(
sm.getString("util.unknownDecoderType"));
}
}
}
}
public List<Class<? extends Decoder>> getTextDecoders() {
return textDecoders;
}
public List<Class<? extends Decoder>> getBinaryDecoders() {
return binaryDecoders;
}
public Class<?> getTarget() {
return target;
}
public boolean hasMatches() {
return (textDecoders.size() > 0) || (binaryDecoders.size() > 0);
}
}
private static class TypeResult {
private final Class<?> clazz;
private final int index;
private int dimension;
public TypeResult(Class<?> clazz, int index, int dimension) {
this.clazz= clazz;
this.index = index;
this.dimension = dimension;
}
public Class<?> getClazz() {
return clazz;
}
public int getIndex() {
return index;
}
public int getDimension() {
return dimension;
}
public void incrementDimension(int inc) {
dimension += inc;
}
}
}

View File

@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import javax.websocket.MessageHandler;
public interface WrappedMessageHandler {
long getMaxMessageSize();
MessageHandler getWrappedHandler();
}

View File

@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import javax.websocket.ContainerProvider;
import javax.websocket.WebSocketContainer;
public class WsContainerProvider extends ContainerProvider {
@Override
protected WebSocketContainer getContainer() {
return new WsWebSocketContainer();
}
}

View File

@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.ArrayList;
import java.util.List;
import javax.websocket.Extension;
public class WsExtension implements Extension {
private final String name;
private final List<Parameter> parameters = new ArrayList<>();
WsExtension(String name) {
this.name = name;
}
void addParameter(Parameter parameter) {
parameters.add(parameter);
}
@Override
public String getName() {
return name;
}
@Override
public List<Parameter> getParameters() {
return parameters;
}
}

View File

@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import javax.websocket.Extension.Parameter;
public class WsExtensionParameter implements Parameter {
private final String name;
private final String value;
WsExtensionParameter(String name, String value) {
this.name = name;
this.value = value;
}
@Override
public String getName() {
return name;
}
@Override
public String getValue() {
return value;
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.res.StringManager;
public class WsFrameClient extends WsFrameBase {
private final Log log = LogFactory.getLog(WsFrameClient.class); // must not be static
private static final StringManager sm = StringManager.getManager(WsFrameClient.class);
private final AsyncChannelWrapper channel;
private final CompletionHandler<Integer, Void> handler;
// Not final as it may need to be re-sized
private volatile ByteBuffer response;
public WsFrameClient(ByteBuffer response, AsyncChannelWrapper channel, WsSession wsSession,
Transformation transformation) {
super(wsSession, transformation);
this.response = response;
this.channel = channel;
this.handler = new WsFrameClientCompletionHandler();
}
void startInputProcessing() {
try {
processSocketRead();
} catch (IOException e) {
close(e);
}
}
private void processSocketRead() throws IOException {
while (true) {
switch (getReadState()) {
case WAITING:
if (!changeReadState(ReadState.WAITING, ReadState.PROCESSING)) {
continue;
}
while (response.hasRemaining()) {
if (isSuspended()) {
if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) {
continue;
}
// There is still data available in the response buffer
// Return here so that the response buffer will not be
// cleared and there will be no data read from the
// socket. Thus when the read operation is resumed first
// the data left in the response buffer will be consumed
// and then a new socket read will be performed
return;
}
inputBuffer.mark();
inputBuffer.position(inputBuffer.limit()).limit(inputBuffer.capacity());
int toCopy = Math.min(response.remaining(), inputBuffer.remaining());
// Copy remaining bytes read in HTTP phase to input buffer used by
// frame processing
int orgLimit = response.limit();
response.limit(response.position() + toCopy);
inputBuffer.put(response);
response.limit(orgLimit);
inputBuffer.limit(inputBuffer.position()).reset();
// Process the data we have
processInputBuffer();
}
response.clear();
// Get some more data
if (isOpen()) {
channel.read(response, null, handler);
} else {
changeReadState(ReadState.CLOSING);
}
return;
case SUSPENDING_WAIT:
if (!changeReadState(ReadState.SUSPENDING_WAIT, ReadState.SUSPENDED)) {
continue;
}
return;
default:
throw new IllegalStateException(
sm.getString("wsFrameServer.illegalReadState", getReadState()));
}
}
}
private final void close(Throwable t) {
changeReadState(ReadState.CLOSING);
CloseReason cr;
if (t instanceof WsIOException) {
cr = ((WsIOException) t).getCloseReason();
} else {
cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, t.getMessage());
}
try {
wsSession.close(cr);
} catch (IOException ignore) {
// Ignore
}
}
@Override
protected boolean isMasked() {
// Data is from the server so it is not masked
return false;
}
@Override
protected Log getLog() {
return log;
}
private class WsFrameClientCompletionHandler implements CompletionHandler<Integer, Void> {
@Override
public void completed(Integer result, Void attachment) {
if (result.intValue() == -1) {
// BZ 57762. A dropped connection will get reported as EOF
// rather than as an error so handle it here.
if (isOpen()) {
// No close frame was received
close(new EOFException());
}
// No data to process
return;
}
response.flip();
doResumeProcessing(true);
}
@Override
public void failed(Throwable exc, Void attachment) {
if (exc instanceof ReadBufferOverflowException) {
// response will be empty if this exception is thrown
response = ByteBuffer
.allocate(((ReadBufferOverflowException) exc).getMinBufferSize());
response.flip();
doResumeProcessing(false);
} else {
close(exc);
}
}
private void doResumeProcessing(boolean checkOpenOnError) {
while (true) {
switch (getReadState()) {
case PROCESSING:
if (!changeReadState(ReadState.PROCESSING, ReadState.WAITING)) {
continue;
}
resumeProcessing(checkOpenOnError);
return;
case SUSPENDING_PROCESS:
if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) {
continue;
}
return;
default:
throw new IllegalStateException(
sm.getString("wsFrame.illegalReadState", getReadState()));
}
}
}
}
@Override
protected void resumeProcessing() {
resumeProcessing(true);
}
private void resumeProcessing(boolean checkOpenOnError) {
try {
processSocketRead();
} catch (IOException e) {
if (checkOpenOnError) {
// Only send a close message on an IOException if the client
// has not yet received a close control message from the server
// as the IOException may be in response to the client
// continuing to send a message after the server sent a close
// control message.
if (isOpen()) {
if (log.isDebugEnabled()) {
log.debug(sm.getString("wsFrameClient.ioe"), e);
}
close(e);
}
} else {
close(e);
}
}
}
}

View File

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import javax.websocket.HandshakeResponse;
import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
/**
* Represents the response to a WebSocket handshake.
*/
public class WsHandshakeResponse implements HandshakeResponse {
private final Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>();
public WsHandshakeResponse() {
}
public WsHandshakeResponse(Map<String,List<String>> headers) {
for (Entry<String,List<String>> entry : headers.entrySet()) {
if (this.headers.containsKey(entry.getKey())) {
this.headers.get(entry.getKey()).addAll(entry.getValue());
} else {
List<String> values = new ArrayList<>(entry.getValue());
this.headers.put(entry.getKey(), values);
}
}
}
@Override
public Map<String,List<String>> getHeaders() {
return headers;
}
}

View File

@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import javax.websocket.CloseReason;
/**
* Allows the WebSocket implementation to throw an {@link IOException} that
* includes a {@link CloseReason} specific to the error that can be passed back
* to the client.
*/
public class WsIOException extends IOException {
private static final long serialVersionUID = 1L;
private final CloseReason closeReason;
public WsIOException(CloseReason closeReason) {
this.closeReason = closeReason;
}
public CloseReason getCloseReason() {
return closeReason;
}
}

View File

@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.nio.ByteBuffer;
import javax.websocket.PongMessage;
public class WsPongMessage implements PongMessage {
private final ByteBuffer applicationData;
public WsPongMessage(ByteBuffer applicationData) {
byte[] dst = new byte[applicationData.limit()];
applicationData.get(dst);
this.applicationData = ByteBuffer.wrap(dst);
}
@Override
public ByteBuffer getApplicationData() {
return applicationData;
}
}

View File

@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.nio.ByteBuffer;
import java.util.concurrent.Future;
import javax.websocket.RemoteEndpoint;
import javax.websocket.SendHandler;
public class WsRemoteEndpointAsync extends WsRemoteEndpointBase
implements RemoteEndpoint.Async {
WsRemoteEndpointAsync(WsRemoteEndpointImplBase base) {
super(base);
}
@Override
public long getSendTimeout() {
return base.getSendTimeout();
}
@Override
public void setSendTimeout(long timeout) {
base.setSendTimeout(timeout);
}
@Override
public void sendText(String text, SendHandler completion) {
base.sendStringByCompletion(text, completion);
}
@Override
public Future<Void> sendText(String text) {
return base.sendStringByFuture(text);
}
@Override
public Future<Void> sendBinary(ByteBuffer data) {
return base.sendBytesByFuture(data);
}
@Override
public void sendBinary(ByteBuffer data, SendHandler completion) {
base.sendBytesByCompletion(data, completion);
}
@Override
public Future<Void> sendObject(Object obj) {
return base.sendObjectByFuture(obj);
}
@Override
public void sendObject(Object obj, SendHandler completion) {
base.sendObjectByCompletion(obj, completion);
}
}

View File

@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import javax.websocket.RemoteEndpoint;
public abstract class WsRemoteEndpointBase implements RemoteEndpoint {
protected final WsRemoteEndpointImplBase base;
WsRemoteEndpointBase(WsRemoteEndpointImplBase base) {
this.base = base;
}
@Override
public final void setBatchingAllowed(boolean batchingAllowed) throws IOException {
base.setBatchingAllowed(batchingAllowed);
}
@Override
public final boolean getBatchingAllowed() {
return base.getBatchingAllowed();
}
@Override
public final void flushBatch() throws IOException {
base.flushBatch();
}
@Override
public final void sendPing(ByteBuffer applicationData) throws IOException,
IllegalArgumentException {
base.sendPing(applicationData);
}
@Override
public final void sendPong(ByteBuffer applicationData) throws IOException,
IllegalArgumentException {
base.sendPong(applicationData);
}
}

View File

@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Writer;
import java.nio.ByteBuffer;
import javax.websocket.EncodeException;
import javax.websocket.RemoteEndpoint;
public class WsRemoteEndpointBasic extends WsRemoteEndpointBase
implements RemoteEndpoint.Basic {
WsRemoteEndpointBasic(WsRemoteEndpointImplBase base) {
super(base);
}
@Override
public void sendText(String text) throws IOException {
base.sendString(text);
}
@Override
public void sendBinary(ByteBuffer data) throws IOException {
base.sendBytes(data);
}
@Override
public void sendText(String fragment, boolean isLast) throws IOException {
base.sendPartialString(fragment, isLast);
}
@Override
public void sendBinary(ByteBuffer partialByte, boolean isLast)
throws IOException {
base.sendPartialBytes(partialByte, isLast);
}
@Override
public OutputStream getSendStream() throws IOException {
return base.getSendStream();
}
@Override
public Writer getSendWriter() throws IOException {
return base.getSendWriter();
}
@Override
public void sendObject(Object o) throws IOException, EncodeException {
base.sendObject(o);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase {
private final AsyncChannelWrapper channel;
public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) {
this.channel = channel;
}
@Override
protected boolean isMasked() {
return true;
}
@Override
protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
ByteBuffer... data) {
long timeout;
for (ByteBuffer byteBuffer : data) {
if (blockingWriteTimeoutExpiry == -1) {
timeout = getSendTimeout();
if (timeout < 1) {
timeout = Long.MAX_VALUE;
}
} else {
timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis();
if (timeout < 0) {
SendResult sr = new SendResult(new IOException("Blocking write timeout"));
handler.onResult(sr);
}
}
try {
channel.write(byteBuffer).get(timeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
handler.onResult(new SendResult(e));
return;
}
}
handler.onResult(SENDRESULT_OK);
}
@Override
protected void doClose() {
channel.close();
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
/**
* Internal implementation constants.
*/
public class Constants {
public static final String POJO_PATH_PARAM_KEY =
"nginx.unit.websocket.pojo.PojoEndpoint.pathParams";
public static final String POJO_METHOD_MAPPING_KEY =
"nginx.unit.websocket.pojo.PojoEndpoint.methodMapping";
private Constants() {
// Hide default constructor
}
}

View File

@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pojoEndpointBase.closeSessionFail=Failed to close WebSocket session during error handling
pojoEndpointBase.onCloseFail=Failed to call onClose method of POJO end point for POJO of type [{0}]
pojoEndpointBase.onError=No error handling configured for [{0}] and the following error occurred
pojoEndpointBase.onErrorFail=Failed to call onError method of POJO end point for POJO of type [{0}]
pojoEndpointBase.onOpenFail=Failed to call onOpen method of POJO end point for POJO of type [{0}]
pojoEndpointServer.getPojoInstanceFail=Failed to create instance of POJO of type [{0}]
pojoMethodMapping.decodePathParamFail=Failed to decode path parameter value [{0}] to expected type [{1}]
pojoMethodMapping.duplicateAnnotation=Duplicate annotations [{0}] present on class [{1}]
pojoMethodMapping.duplicateLastParam=Multiple boolean (last) parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.duplicateMessageParam=Multiple message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.duplicatePongMessageParam=Multiple PongMessage parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.duplicateSessionParam=Multiple session parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.invalidDecoder=The specified decoder of type [{0}] could not be instantiated
pojoMethodMapping.invalidPathParamType=Parameters annotated with @PathParam may only be Strings, Java primitives or a boxed version thereof
pojoMethodMapping.methodNotPublic=The annotated method [{0}] is not public
pojoMethodMapping.noPayload=No payload parameter present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.onErrorNoThrowable=No Throwable parameter was present on the method [{0}] of class [{1}] that was annotated with OnError
pojoMethodMapping.paramWithoutAnnotation=A parameter of type [{0}] was found on method[{1}] of class [{2}] that did not have a @PathParam annotation
pojoMethodMapping.partialInputStream=Invalid InputStream and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.partialObject=Invalid Object and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.partialPong=Invalid PongMessage and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.partialReader=Invalid Reader and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMethodMapping.pongWithPayload=Invalid PongMessage and Message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
pojoMessageHandlerWhole.decodeIoFail=IO error while decoding message
pojoMessageHandlerWhole.maxBufferSize=The maximum supported message size for this implementation is Integer.MAX_VALUE

View File

@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.Map;
import java.util.Set;
import javax.websocket.CloseReason;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.ExceptionUtils;
import org.apache.tomcat.util.res.StringManager;
/**
* Base implementation (client and server have different concrete
* implementations) of the wrapper that converts a POJO instance into a
* WebSocket endpoint instance.
*/
public abstract class PojoEndpointBase extends Endpoint {
private final Log log = LogFactory.getLog(PojoEndpointBase.class); // must not be static
private static final StringManager sm = StringManager.getManager(PojoEndpointBase.class);
private Object pojo;
private Map<String,String> pathParameters;
private PojoMethodMapping methodMapping;
protected final void doOnOpen(Session session, EndpointConfig config) {
PojoMethodMapping methodMapping = getMethodMapping();
Object pojo = getPojo();
Map<String,String> pathParameters = getPathParameters();
// Add message handlers before calling onOpen since that may trigger a
// message which in turn could trigger a response and/or close the
// session
for (MessageHandler mh : methodMapping.getMessageHandlers(pojo,
pathParameters, session, config)) {
session.addMessageHandler(mh);
}
if (methodMapping.getOnOpen() != null) {
try {
methodMapping.getOnOpen().invoke(pojo,
methodMapping.getOnOpenArgs(
pathParameters, session, config));
} catch (IllegalAccessException e) {
// Reflection related problems
log.error(sm.getString(
"pojoEndpointBase.onOpenFail",
pojo.getClass().getName()), e);
handleOnOpenOrCloseError(session, e);
} catch (InvocationTargetException e) {
Throwable cause = e.getCause();
handleOnOpenOrCloseError(session, cause);
} catch (Throwable t) {
handleOnOpenOrCloseError(session, t);
}
}
}
private void handleOnOpenOrCloseError(Session session, Throwable t) {
// If really fatal - re-throw
ExceptionUtils.handleThrowable(t);
// Trigger the error handler and close the session
onError(session, t);
try {
session.close();
} catch (IOException ioe) {
log.warn(sm.getString("pojoEndpointBase.closeSessionFail"), ioe);
}
}
@Override
public final void onClose(Session session, CloseReason closeReason) {
if (methodMapping.getOnClose() != null) {
try {
methodMapping.getOnClose().invoke(pojo,
methodMapping.getOnCloseArgs(pathParameters, session, closeReason));
} catch (Throwable t) {
log.error(sm.getString("pojoEndpointBase.onCloseFail",
pojo.getClass().getName()), t);
handleOnOpenOrCloseError(session, t);
}
}
// Trigger the destroy method for any associated decoders
Set<MessageHandler> messageHandlers = session.getMessageHandlers();
for (MessageHandler messageHandler : messageHandlers) {
if (messageHandler instanceof PojoMessageHandlerWholeBase<?>) {
((PojoMessageHandlerWholeBase<?>) messageHandler).onClose();
}
}
}
@Override
public final void onError(Session session, Throwable throwable) {
if (methodMapping.getOnError() == null) {
log.error(sm.getString("pojoEndpointBase.onError",
pojo.getClass().getName()), throwable);
} else {
try {
methodMapping.getOnError().invoke(
pojo,
methodMapping.getOnErrorArgs(pathParameters, session,
throwable));
} catch (Throwable t) {
ExceptionUtils.handleThrowable(t);
log.error(sm.getString("pojoEndpointBase.onErrorFail",
pojo.getClass().getName()), t);
}
}
}
protected Object getPojo() { return pojo; }
protected void setPojo(Object pojo) { this.pojo = pojo; }
protected Map<String,String> getPathParameters() { return pathParameters; }
protected void setPathParameters(Map<String,String> pathParameters) {
this.pathParameters = pathParameters;
}
protected PojoMethodMapping getMethodMapping() { return methodMapping; }
protected void setMethodMapping(PojoMethodMapping methodMapping) {
this.methodMapping = methodMapping;
}
}

View File

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.util.Collections;
import java.util.List;
import javax.websocket.Decoder;
import javax.websocket.DeploymentException;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
/**
* Wrapper class for instances of POJOs annotated with
* {@link javax.websocket.ClientEndpoint} so they appear as standard
* {@link javax.websocket.Endpoint} instances.
*/
public class PojoEndpointClient extends PojoEndpointBase {
public PojoEndpointClient(Object pojo,
List<Class<? extends Decoder>> decoders) throws DeploymentException {
setPojo(pojo);
setMethodMapping(
new PojoMethodMapping(pojo.getClass(), decoders, null));
setPathParameters(Collections.<String,String>emptyMap());
}
@Override
public void onOpen(Session session, EndpointConfig config) {
doOnOpen(session, config);
}
}

View File

@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.util.Map;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpointConfig;
import org.apache.tomcat.util.res.StringManager;
/**
* Wrapper class for instances of POJOs annotated with
* {@link javax.websocket.server.ServerEndpoint} so they appear as standard
* {@link javax.websocket.Endpoint} instances.
*/
public class PojoEndpointServer extends PojoEndpointBase {
private static final StringManager sm =
StringManager.getManager(PojoEndpointServer.class);
@Override
public void onOpen(Session session, EndpointConfig endpointConfig) {
ServerEndpointConfig sec = (ServerEndpointConfig) endpointConfig;
Object pojo;
try {
pojo = sec.getConfigurator().getEndpointInstance(
sec.getEndpointClass());
} catch (InstantiationException e) {
throw new IllegalArgumentException(sm.getString(
"pojoEndpointServer.getPojoInstanceFail",
sec.getEndpointClass().getName()), e);
}
setPojo(pojo);
@SuppressWarnings("unchecked")
Map<String,String> pathParameters =
(Map<String, String>) sec.getUserProperties().get(
Constants.POJO_PATH_PARAM_KEY);
setPathParameters(pathParameters);
PojoMethodMapping methodMapping =
(PojoMethodMapping) sec.getUserProperties().get(
Constants.POJO_METHOD_MAPPING_KEY);
setMethodMapping(methodMapping);
doOnOpen(session, endpointConfig);
}
}

View File

@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import javax.websocket.EncodeException;
import javax.websocket.MessageHandler;
import javax.websocket.RemoteEndpoint;
import javax.websocket.Session;
import org.apache.tomcat.util.ExceptionUtils;
import nginx.unit.websocket.WrappedMessageHandler;
/**
* Common implementation code for the POJO message handlers.
*
* @param <T> The type of message to handle
*/
public abstract class PojoMessageHandlerBase<T>
implements WrappedMessageHandler {
protected final Object pojo;
protected final Method method;
protected final Session session;
protected final Object[] params;
protected final int indexPayload;
protected final boolean convert;
protected final int indexSession;
protected final long maxMessageSize;
public PojoMessageHandlerBase(Object pojo, Method method,
Session session, Object[] params, int indexPayload, boolean convert,
int indexSession, long maxMessageSize) {
this.pojo = pojo;
this.method = method;
// TODO: The method should already be accessible here but the following
// code seems to be necessary in some as yet not fully understood cases.
try {
this.method.setAccessible(true);
} catch (Exception e) {
// It is better to make sure the method is accessible, but
// ignore exceptions and hope for the best
}
this.session = session;
this.params = params;
this.indexPayload = indexPayload;
this.convert = convert;
this.indexSession = indexSession;
this.maxMessageSize = maxMessageSize;
}
protected final void processResult(Object result) {
if (result == null) {
return;
}
RemoteEndpoint.Basic remoteEndpoint = session.getBasicRemote();
try {
if (result instanceof String) {
remoteEndpoint.sendText((String) result);
} else if (result instanceof ByteBuffer) {
remoteEndpoint.sendBinary((ByteBuffer) result);
} else if (result instanceof byte[]) {
remoteEndpoint.sendBinary(ByteBuffer.wrap((byte[]) result));
} else {
remoteEndpoint.sendObject(result);
}
} catch (IOException | EncodeException ioe) {
throw new IllegalStateException(ioe);
}
}
/**
* Expose the POJO if it is a message handler so the Session is able to
* match requests to remove handlers if the original handler has been
* wrapped.
*/
@Override
public final MessageHandler getWrappedHandler() {
if (pojo instanceof MessageHandler) {
return (MessageHandler) pojo;
} else {
return null;
}
}
@Override
public final long getMaxMessageSize() {
return maxMessageSize;
}
protected final void handlePojoMethodException(Throwable t) {
t = ExceptionUtils.unwrapInvocationTargetException(t);
ExceptionUtils.handleThrowable(t);
if (t instanceof RuntimeException) {
throw (RuntimeException) t;
} else {
throw new RuntimeException(t.getMessage(), t);
}
}
}

View File

@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import javax.websocket.DecodeException;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import nginx.unit.websocket.WsSession;
/**
* Common implementation code for the POJO partial message handlers. All
* the real work is done in this class and in the superclass.
*
* @param <T> The type of message to handle
*/
public abstract class PojoMessageHandlerPartialBase<T>
extends PojoMessageHandlerBase<T> implements MessageHandler.Partial<T> {
private final int indexBoolean;
public PojoMessageHandlerPartialBase(Object pojo, Method method,
Session session, Object[] params, int indexPayload,
boolean convert, int indexBoolean, int indexSession,
long maxMessageSize) {
super(pojo, method, session, params, indexPayload, convert,
indexSession, maxMessageSize);
this.indexBoolean = indexBoolean;
}
@Override
public final void onMessage(T message, boolean last) {
if (params.length == 1 && params[0] instanceof DecodeException) {
((WsSession) session).getLocal().onError(session,
(DecodeException) params[0]);
return;
}
Object[] parameters = params.clone();
if (indexBoolean != -1) {
parameters[indexBoolean] = Boolean.valueOf(last);
}
if (indexSession != -1) {
parameters[indexSession] = session;
}
if (convert) {
parameters[indexPayload] = ((ByteBuffer) message).array();
} else {
parameters[indexPayload] = message;
}
Object result = null;
try {
result = method.invoke(pojo, parameters);
} catch (IllegalAccessException | InvocationTargetException e) {
handlePojoMethodException(e);
}
processResult(result);
}
}

View File

@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import javax.websocket.Session;
/**
* ByteBuffer specific concrete implementation for handling partial messages.
*/
public class PojoMessageHandlerPartialBinary
extends PojoMessageHandlerPartialBase<ByteBuffer> {
public PojoMessageHandlerPartialBinary(Object pojo, Method method,
Session session, Object[] params, int indexPayload, boolean convert,
int indexBoolean, int indexSession, long maxMessageSize) {
super(pojo, method, session, params, indexPayload, convert, indexBoolean,
indexSession, maxMessageSize);
}
}

View File

@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.lang.reflect.Method;
import javax.websocket.Session;
/**
* Text specific concrete implementation for handling partial messages.
*/
public class PojoMessageHandlerPartialText
extends PojoMessageHandlerPartialBase<String> {
public PojoMessageHandlerPartialText(Object pojo, Method method,
Session session, Object[] params, int indexPayload, boolean convert,
int indexBoolean, int indexSession, long maxMessageSize) {
super(pojo, method, session, params, indexPayload, convert, indexBoolean,
indexSession, maxMessageSize);
}
}

View File

@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import javax.websocket.DecodeException;
import javax.websocket.MessageHandler;
import javax.websocket.Session;
import nginx.unit.websocket.WsSession;
/**
* Common implementation code for the POJO whole message handlers. All the real
* work is done in this class and in the superclass.
*
* @param <T> The type of message to handle
*/
public abstract class PojoMessageHandlerWholeBase<T>
extends PojoMessageHandlerBase<T> implements MessageHandler.Whole<T> {
public PojoMessageHandlerWholeBase(Object pojo, Method method,
Session session, Object[] params, int indexPayload,
boolean convert, int indexSession, long maxMessageSize) {
super(pojo, method, session, params, indexPayload, convert,
indexSession, maxMessageSize);
}
@Override
public final void onMessage(T message) {
if (params.length == 1 && params[0] instanceof DecodeException) {
((WsSession) session).getLocal().onError(session,
(DecodeException) params[0]);
return;
}
// Can this message be decoded?
Object payload;
try {
payload = decode(message);
} catch (DecodeException de) {
((WsSession) session).getLocal().onError(session, de);
return;
}
if (payload == null) {
// Not decoded. Convert if required.
if (convert) {
payload = convert(message);
} else {
payload = message;
}
}
Object[] parameters = params.clone();
if (indexSession != -1) {
parameters[indexSession] = session;
}
parameters[indexPayload] = payload;
Object result = null;
try {
result = method.invoke(pojo, parameters);
} catch (IllegalAccessException | InvocationTargetException e) {
handlePojoMethodException(e);
}
processResult(result);
}
protected Object convert(T message) {
return message;
}
protected abstract Object decode(T message) throws DecodeException;
protected abstract void onClose();
}

View File

@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import javax.websocket.DecodeException;
import javax.websocket.Decoder;
import javax.websocket.Decoder.Binary;
import javax.websocket.Decoder.BinaryStream;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import org.apache.tomcat.util.res.StringManager;
/**
* ByteBuffer specific concrete implementation for handling whole messages.
*/
public class PojoMessageHandlerWholeBinary
extends PojoMessageHandlerWholeBase<ByteBuffer> {
private static final StringManager sm =
StringManager.getManager(PojoMessageHandlerWholeBinary.class);
private final List<Decoder> decoders = new ArrayList<>();
private final boolean isForInputStream;
public PojoMessageHandlerWholeBinary(Object pojo, Method method,
Session session, EndpointConfig config,
List<Class<? extends Decoder>> decoderClazzes, Object[] params,
int indexPayload, boolean convert, int indexSession,
boolean isForInputStream, long maxMessageSize) {
super(pojo, method, session, params, indexPayload, convert,
indexSession, maxMessageSize);
// Update binary text size handled by session
if (maxMessageSize > -1 && maxMessageSize > session.getMaxBinaryMessageBufferSize()) {
if (maxMessageSize > Integer.MAX_VALUE) {
throw new IllegalArgumentException(sm.getString(
"pojoMessageHandlerWhole.maxBufferSize"));
}
session.setMaxBinaryMessageBufferSize((int) maxMessageSize);
}
try {
if (decoderClazzes != null) {
for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
if (Binary.class.isAssignableFrom(decoderClazz)) {
Binary<?> decoder = (Binary<?>) decoderClazz.getConstructor().newInstance();
decoder.init(config);
decoders.add(decoder);
} else if (BinaryStream.class.isAssignableFrom(
decoderClazz)) {
BinaryStream<?> decoder = (BinaryStream<?>)
decoderClazz.getConstructor().newInstance();
decoder.init(config);
decoders.add(decoder);
} else {
// Text decoder - ignore it
}
}
}
} catch (ReflectiveOperationException e) {
throw new IllegalArgumentException(e);
}
this.isForInputStream = isForInputStream;
}
@Override
protected Object decode(ByteBuffer message) throws DecodeException {
for (Decoder decoder : decoders) {
if (decoder instanceof Binary) {
if (((Binary<?>) decoder).willDecode(message)) {
return ((Binary<?>) decoder).decode(message);
}
} else {
byte[] array = new byte[message.limit() - message.position()];
message.get(array);
ByteArrayInputStream bais = new ByteArrayInputStream(array);
try {
return ((BinaryStream<?>) decoder).decode(bais);
} catch (IOException ioe) {
throw new DecodeException(message, sm.getString(
"pojoMessageHandlerWhole.decodeIoFail"), ioe);
}
}
}
return null;
}
@Override
protected Object convert(ByteBuffer message) {
byte[] array = new byte[message.remaining()];
message.get(array);
if (isForInputStream) {
return new ByteArrayInputStream(array);
} else {
return array;
}
}
@Override
protected void onClose() {
for (Decoder decoder : decoders) {
decoder.destroy();
}
}
}

View File

@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.lang.reflect.Method;
import javax.websocket.PongMessage;
import javax.websocket.Session;
/**
* PongMessage specific concrete implementation for handling whole messages.
*/
public class PojoMessageHandlerWholePong
extends PojoMessageHandlerWholeBase<PongMessage> {
public PojoMessageHandlerWholePong(Object pojo, Method method,
Session session, Object[] params, int indexPayload, boolean convert,
int indexSession) {
super(pojo, method, session, params, indexPayload, convert,
indexSession, -1);
}
@Override
protected Object decode(PongMessage message) {
// Never decoded
return null;
}
@Override
protected void onClose() {
// NO-OP
}
}

View File

@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.io.IOException;
import java.io.StringReader;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import javax.websocket.DecodeException;
import javax.websocket.Decoder;
import javax.websocket.Decoder.Text;
import javax.websocket.Decoder.TextStream;
import javax.websocket.EndpointConfig;
import javax.websocket.Session;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.Util;
/**
* Text specific concrete implementation for handling whole messages.
*/
public class PojoMessageHandlerWholeText
extends PojoMessageHandlerWholeBase<String> {
private static final StringManager sm =
StringManager.getManager(PojoMessageHandlerWholeText.class);
private final List<Decoder> decoders = new ArrayList<>();
private final Class<?> primitiveType;
public PojoMessageHandlerWholeText(Object pojo, Method method,
Session session, EndpointConfig config,
List<Class<? extends Decoder>> decoderClazzes, Object[] params,
int indexPayload, boolean convert, int indexSession,
long maxMessageSize) {
super(pojo, method, session, params, indexPayload, convert,
indexSession, maxMessageSize);
// Update max text size handled by session
if (maxMessageSize > -1 && maxMessageSize > session.getMaxTextMessageBufferSize()) {
if (maxMessageSize > Integer.MAX_VALUE) {
throw new IllegalArgumentException(sm.getString(
"pojoMessageHandlerWhole.maxBufferSize"));
}
session.setMaxTextMessageBufferSize((int) maxMessageSize);
}
// Check for primitives
Class<?> type = method.getParameterTypes()[indexPayload];
if (Util.isPrimitive(type)) {
primitiveType = type;
return;
} else {
primitiveType = null;
}
try {
if (decoderClazzes != null) {
for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
if (Text.class.isAssignableFrom(decoderClazz)) {
Text<?> decoder = (Text<?>) decoderClazz.getConstructor().newInstance();
decoder.init(config);
decoders.add(decoder);
} else if (TextStream.class.isAssignableFrom(
decoderClazz)) {
TextStream<?> decoder =
(TextStream<?>) decoderClazz.getConstructor().newInstance();
decoder.init(config);
decoders.add(decoder);
} else {
// Binary decoder - ignore it
}
}
}
} catch (ReflectiveOperationException e) {
throw new IllegalArgumentException(e);
}
}
@Override
protected Object decode(String message) throws DecodeException {
// Handle primitives
if (primitiveType != null) {
return Util.coerceToType(primitiveType, message);
}
// Handle full decoders
for (Decoder decoder : decoders) {
if (decoder instanceof Text) {
if (((Text<?>) decoder).willDecode(message)) {
return ((Text<?>) decoder).decode(message);
}
} else {
StringReader r = new StringReader(message);
try {
return ((TextStream<?>) decoder).decode(r);
} catch (IOException ioe) {
throw new DecodeException(message, sm.getString(
"pojoMessageHandlerWhole.decodeIoFail"), ioe);
}
}
}
return null;
}
@Override
protected Object convert(String message) {
return new StringReader(message);
}
@Override
protected void onClose() {
for (Decoder decoder : decoders) {
decoder.destroy();
}
}
}

View File

@@ -0,0 +1,731 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
import java.io.InputStream;
import java.io.Reader;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.websocket.CloseReason;
import javax.websocket.DecodeException;
import javax.websocket.Decoder;
import javax.websocket.DeploymentException;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.PongMessage;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.DecoderEntry;
import nginx.unit.websocket.Util;
import nginx.unit.websocket.Util.DecoderMatch;
/**
* For a POJO class annotated with
* {@link javax.websocket.server.ServerEndpoint}, an instance of this class
* creates and caches the method handler, method information and parameter
* information for the onXXX calls.
*/
public class PojoMethodMapping {
private static final StringManager sm =
StringManager.getManager(PojoMethodMapping.class);
private final Method onOpen;
private final Method onClose;
private final Method onError;
private final PojoPathParam[] onOpenParams;
private final PojoPathParam[] onCloseParams;
private final PojoPathParam[] onErrorParams;
private final List<MessageHandlerInfo> onMessage = new ArrayList<>();
private final String wsPath;
public PojoMethodMapping(Class<?> clazzPojo,
List<Class<? extends Decoder>> decoderClazzes, String wsPath)
throws DeploymentException {
this.wsPath = wsPath;
List<DecoderEntry> decoders = Util.getDecoders(decoderClazzes);
Method open = null;
Method close = null;
Method error = null;
Method[] clazzPojoMethods = null;
Class<?> currentClazz = clazzPojo;
while (!currentClazz.equals(Object.class)) {
Method[] currentClazzMethods = currentClazz.getDeclaredMethods();
if (currentClazz == clazzPojo) {
clazzPojoMethods = currentClazzMethods;
}
for (Method method : currentClazzMethods) {
if (method.getAnnotation(OnOpen.class) != null) {
checkPublic(method);
if (open == null) {
open = method;
} else {
if (currentClazz == clazzPojo ||
!isMethodOverride(open, method)) {
// Duplicate annotation
throw new DeploymentException(sm.getString(
"pojoMethodMapping.duplicateAnnotation",
OnOpen.class, currentClazz));
}
}
} else if (method.getAnnotation(OnClose.class) != null) {
checkPublic(method);
if (close == null) {
close = method;
} else {
if (currentClazz == clazzPojo ||
!isMethodOverride(close, method)) {
// Duplicate annotation
throw new DeploymentException(sm.getString(
"pojoMethodMapping.duplicateAnnotation",
OnClose.class, currentClazz));
}
}
} else if (method.getAnnotation(OnError.class) != null) {
checkPublic(method);
if (error == null) {
error = method;
} else {
if (currentClazz == clazzPojo ||
!isMethodOverride(error, method)) {
// Duplicate annotation
throw new DeploymentException(sm.getString(
"pojoMethodMapping.duplicateAnnotation",
OnError.class, currentClazz));
}
}
} else if (method.getAnnotation(OnMessage.class) != null) {
checkPublic(method);
MessageHandlerInfo messageHandler = new MessageHandlerInfo(method, decoders);
boolean found = false;
for (MessageHandlerInfo otherMessageHandler : onMessage) {
if (messageHandler.targetsSameWebSocketMessageType(otherMessageHandler)) {
found = true;
if (currentClazz == clazzPojo ||
!isMethodOverride(messageHandler.m, otherMessageHandler.m)) {
// Duplicate annotation
throw new DeploymentException(sm.getString(
"pojoMethodMapping.duplicateAnnotation",
OnMessage.class, currentClazz));
}
}
}
if (!found) {
onMessage.add(messageHandler);
}
} else {
// Method not annotated
}
}
currentClazz = currentClazz.getSuperclass();
}
// If the methods are not on clazzPojo and they are overridden
// by a non annotated method in clazzPojo, they should be ignored
if (open != null && open.getDeclaringClass() != clazzPojo) {
if (isOverridenWithoutAnnotation(clazzPojoMethods, open, OnOpen.class)) {
open = null;
}
}
if (close != null && close.getDeclaringClass() != clazzPojo) {
if (isOverridenWithoutAnnotation(clazzPojoMethods, close, OnClose.class)) {
close = null;
}
}
if (error != null && error.getDeclaringClass() != clazzPojo) {
if (isOverridenWithoutAnnotation(clazzPojoMethods, error, OnError.class)) {
error = null;
}
}
List<MessageHandlerInfo> overriddenOnMessage = new ArrayList<>();
for (MessageHandlerInfo messageHandler : onMessage) {
if (messageHandler.m.getDeclaringClass() != clazzPojo
&& isOverridenWithoutAnnotation(clazzPojoMethods, messageHandler.m, OnMessage.class)) {
overriddenOnMessage.add(messageHandler);
}
}
for (MessageHandlerInfo messageHandler : overriddenOnMessage) {
onMessage.remove(messageHandler);
}
this.onOpen = open;
this.onClose = close;
this.onError = error;
onOpenParams = getPathParams(onOpen, MethodType.ON_OPEN);
onCloseParams = getPathParams(onClose, MethodType.ON_CLOSE);
onErrorParams = getPathParams(onError, MethodType.ON_ERROR);
}
private void checkPublic(Method m) throws DeploymentException {
if (!Modifier.isPublic(m.getModifiers())) {
throw new DeploymentException(sm.getString(
"pojoMethodMapping.methodNotPublic", m.getName()));
}
}
private boolean isMethodOverride(Method method1, Method method2) {
return method1.getName().equals(method2.getName())
&& method1.getReturnType().equals(method2.getReturnType())
&& Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes());
}
private boolean isOverridenWithoutAnnotation(Method[] methods,
Method superclazzMethod, Class<? extends Annotation> annotation) {
for (Method method : methods) {
if (isMethodOverride(method, superclazzMethod)
&& (method.getAnnotation(annotation) == null)) {
return true;
}
}
return false;
}
public String getWsPath() {
return wsPath;
}
public Method getOnOpen() {
return onOpen;
}
public Object[] getOnOpenArgs(Map<String,String> pathParameters,
Session session, EndpointConfig config) throws DecodeException {
return buildArgs(onOpenParams, pathParameters, session, config, null,
null);
}
public Method getOnClose() {
return onClose;
}
public Object[] getOnCloseArgs(Map<String,String> pathParameters,
Session session, CloseReason closeReason) throws DecodeException {
return buildArgs(onCloseParams, pathParameters, session, null, null,
closeReason);
}
public Method getOnError() {
return onError;
}
public Object[] getOnErrorArgs(Map<String,String> pathParameters,
Session session, Throwable throwable) throws DecodeException {
return buildArgs(onErrorParams, pathParameters, session, null,
throwable, null);
}
public boolean hasMessageHandlers() {
return !onMessage.isEmpty();
}
public Set<MessageHandler> getMessageHandlers(Object pojo,
Map<String,String> pathParameters, Session session,
EndpointConfig config) {
Set<MessageHandler> result = new HashSet<>();
for (MessageHandlerInfo messageMethod : onMessage) {
result.addAll(messageMethod.getMessageHandlers(pojo, pathParameters,
session, config));
}
return result;
}
private static PojoPathParam[] getPathParams(Method m,
MethodType methodType) throws DeploymentException {
if (m == null) {
return new PojoPathParam[0];
}
boolean foundThrowable = false;
Class<?>[] types = m.getParameterTypes();
Annotation[][] paramsAnnotations = m.getParameterAnnotations();
PojoPathParam[] result = new PojoPathParam[types.length];
for (int i = 0; i < types.length; i++) {
Class<?> type = types[i];
if (type.equals(Session.class)) {
result[i] = new PojoPathParam(type, null);
} else if (methodType == MethodType.ON_OPEN &&
type.equals(EndpointConfig.class)) {
result[i] = new PojoPathParam(type, null);
} else if (methodType == MethodType.ON_ERROR
&& type.equals(Throwable.class)) {
foundThrowable = true;
result[i] = new PojoPathParam(type, null);
} else if (methodType == MethodType.ON_CLOSE &&
type.equals(CloseReason.class)) {
result[i] = new PojoPathParam(type, null);
} else {
Annotation[] paramAnnotations = paramsAnnotations[i];
for (Annotation paramAnnotation : paramAnnotations) {
if (paramAnnotation.annotationType().equals(
PathParam.class)) {
// Check that the type is valid. "0" coerces to every
// valid type
try {
Util.coerceToType(type, "0");
} catch (IllegalArgumentException iae) {
throw new DeploymentException(sm.getString(
"pojoMethodMapping.invalidPathParamType"),
iae);
}
result[i] = new PojoPathParam(type,
((PathParam) paramAnnotation).value());
break;
}
}
// Parameters without annotations are not permitted
if (result[i] == null) {
throw new DeploymentException(sm.getString(
"pojoMethodMapping.paramWithoutAnnotation",
type, m.getName(), m.getClass().getName()));
}
}
}
if (methodType == MethodType.ON_ERROR && !foundThrowable) {
throw new DeploymentException(sm.getString(
"pojoMethodMapping.onErrorNoThrowable",
m.getName(), m.getDeclaringClass().getName()));
}
return result;
}
private static Object[] buildArgs(PojoPathParam[] pathParams,
Map<String,String> pathParameters, Session session,
EndpointConfig config, Throwable throwable, CloseReason closeReason)
throws DecodeException {
Object[] result = new Object[pathParams.length];
for (int i = 0; i < pathParams.length; i++) {
Class<?> type = pathParams[i].getType();
if (type.equals(Session.class)) {
result[i] = session;
} else if (type.equals(EndpointConfig.class)) {
result[i] = config;
} else if (type.equals(Throwable.class)) {
result[i] = throwable;
} else if (type.equals(CloseReason.class)) {
result[i] = closeReason;
} else {
String name = pathParams[i].getName();
String value = pathParameters.get(name);
try {
result[i] = Util.coerceToType(type, value);
} catch (Exception e) {
throw new DecodeException(value, sm.getString(
"pojoMethodMapping.decodePathParamFail",
value, type), e);
}
}
}
return result;
}
private static class MessageHandlerInfo {
private final Method m;
private int indexString = -1;
private int indexByteArray = -1;
private int indexByteBuffer = -1;
private int indexPong = -1;
private int indexBoolean = -1;
private int indexSession = -1;
private int indexInputStream = -1;
private int indexReader = -1;
private int indexPrimitive = -1;
private Class<?> primitiveType = null;
private Map<Integer,PojoPathParam> indexPathParams = new HashMap<>();
private int indexPayload = -1;
private DecoderMatch decoderMatch = null;
private long maxMessageSize = -1;
public MessageHandlerInfo(Method m, List<DecoderEntry> decoderEntries) {
this.m = m;
Class<?>[] types = m.getParameterTypes();
Annotation[][] paramsAnnotations = m.getParameterAnnotations();
for (int i = 0; i < types.length; i++) {
boolean paramFound = false;
Annotation[] paramAnnotations = paramsAnnotations[i];
for (Annotation paramAnnotation : paramAnnotations) {
if (paramAnnotation.annotationType().equals(
PathParam.class)) {
indexPathParams.put(
Integer.valueOf(i), new PojoPathParam(types[i],
((PathParam) paramAnnotation).value()));
paramFound = true;
break;
}
}
if (paramFound) {
continue;
}
if (String.class.isAssignableFrom(types[i])) {
if (indexString == -1) {
indexString = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (Reader.class.isAssignableFrom(types[i])) {
if (indexReader == -1) {
indexReader = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (boolean.class == types[i]) {
if (indexBoolean == -1) {
indexBoolean = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateLastParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (ByteBuffer.class.isAssignableFrom(types[i])) {
if (indexByteBuffer == -1) {
indexByteBuffer = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (byte[].class == types[i]) {
if (indexByteArray == -1) {
indexByteArray = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (InputStream.class.isAssignableFrom(types[i])) {
if (indexInputStream == -1) {
indexInputStream = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (Util.isPrimitive(types[i])) {
if (indexPrimitive == -1) {
indexPrimitive = i;
primitiveType = types[i];
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (Session.class.isAssignableFrom(types[i])) {
if (indexSession == -1) {
indexSession = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateSessionParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else if (PongMessage.class.isAssignableFrom(types[i])) {
if (indexPong == -1) {
indexPong = i;
} else {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicatePongMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
} else {
if (decoderMatch != null && decoderMatch.hasMatches()) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
}
decoderMatch = new DecoderMatch(types[i], decoderEntries);
if (decoderMatch.hasMatches()) {
indexPayload = i;
}
}
}
// Additional checks required
if (indexString != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexString;
}
}
if (indexReader != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexReader;
}
}
if (indexByteArray != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexByteArray;
}
}
if (indexByteBuffer != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexByteBuffer;
}
}
if (indexInputStream != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexInputStream;
}
}
if (indexPrimitive != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.duplicateMessageParam",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexPrimitive;
}
}
if (indexPong != -1) {
if (indexPayload != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.pongWithPayload",
m.getName(), m.getDeclaringClass().getName()));
} else {
indexPayload = indexPong;
}
}
if (indexPayload == -1 && indexPrimitive == -1 &&
indexBoolean != -1) {
// The boolean we found is a payload, not a last flag
indexPayload = indexBoolean;
indexPrimitive = indexBoolean;
primitiveType = Boolean.TYPE;
indexBoolean = -1;
}
if (indexPayload == -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.noPayload",
m.getName(), m.getDeclaringClass().getName()));
}
if (indexPong != -1 && indexBoolean != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.partialPong",
m.getName(), m.getDeclaringClass().getName()));
}
if(indexReader != -1 && indexBoolean != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.partialReader",
m.getName(), m.getDeclaringClass().getName()));
}
if(indexInputStream != -1 && indexBoolean != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.partialInputStream",
m.getName(), m.getDeclaringClass().getName()));
}
if (decoderMatch != null && decoderMatch.hasMatches() &&
indexBoolean != -1) {
throw new IllegalArgumentException(sm.getString(
"pojoMethodMapping.partialObject",
m.getName(), m.getDeclaringClass().getName()));
}
maxMessageSize = m.getAnnotation(OnMessage.class).maxMessageSize();
}
public boolean targetsSameWebSocketMessageType(MessageHandlerInfo otherHandler) {
if (otherHandler == null) {
return false;
}
if (indexByteArray >= 0 && otherHandler.indexByteArray >= 0) {
return true;
}
if (indexByteBuffer >= 0 && otherHandler.indexByteBuffer >= 0) {
return true;
}
if (indexInputStream >= 0 && otherHandler.indexInputStream >= 0) {
return true;
}
if (indexPong >= 0 && otherHandler.indexPong >= 0) {
return true;
}
if (indexPrimitive >= 0 && otherHandler.indexPrimitive >= 0
&& primitiveType == otherHandler.primitiveType) {
return true;
}
if (indexReader >= 0 && otherHandler.indexReader >= 0) {
return true;
}
if (indexString >= 0 && otherHandler.indexString >= 0) {
return true;
}
if (decoderMatch != null && otherHandler.decoderMatch != null
&& decoderMatch.getTarget().equals(otherHandler.decoderMatch.getTarget())) {
return true;
}
return false;
}
public Set<MessageHandler> getMessageHandlers(Object pojo,
Map<String,String> pathParameters, Session session,
EndpointConfig config) {
Object[] params = new Object[m.getParameterTypes().length];
for (Map.Entry<Integer,PojoPathParam> entry :
indexPathParams.entrySet()) {
PojoPathParam pathParam = entry.getValue();
String valueString = pathParameters.get(pathParam.getName());
Object value = null;
try {
value = Util.coerceToType(pathParam.getType(), valueString);
} catch (Exception e) {
DecodeException de = new DecodeException(valueString,
sm.getString(
"pojoMethodMapping.decodePathParamFail",
valueString, pathParam.getType()), e);
params = new Object[] { de };
break;
}
params[entry.getKey().intValue()] = value;
}
Set<MessageHandler> results = new HashSet<>(2);
if (indexBoolean == -1) {
// Basic
if (indexString != -1 || indexPrimitive != -1) {
MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
session, config, null, params, indexPayload, false,
indexSession, maxMessageSize);
results.add(mh);
} else if (indexReader != -1) {
MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
session, config, null, params, indexReader, true,
indexSession, maxMessageSize);
results.add(mh);
} else if (indexByteArray != -1) {
MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
m, session, config, null, params, indexByteArray,
true, indexSession, false, maxMessageSize);
results.add(mh);
} else if (indexByteBuffer != -1) {
MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
m, session, config, null, params, indexByteBuffer,
false, indexSession, false, maxMessageSize);
results.add(mh);
} else if (indexInputStream != -1) {
MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
m, session, config, null, params, indexInputStream,
true, indexSession, true, maxMessageSize);
results.add(mh);
} else if (decoderMatch != null && decoderMatch.hasMatches()) {
if (decoderMatch.getBinaryDecoders().size() > 0) {
MessageHandler mh = new PojoMessageHandlerWholeBinary(
pojo, m, session, config,
decoderMatch.getBinaryDecoders(), params,
indexPayload, true, indexSession, true,
maxMessageSize);
results.add(mh);
}
if (decoderMatch.getTextDecoders().size() > 0) {
MessageHandler mh = new PojoMessageHandlerWholeText(
pojo, m, session, config,
decoderMatch.getTextDecoders(), params,
indexPayload, true, indexSession, maxMessageSize);
results.add(mh);
}
} else {
MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m,
session, params, indexPong, false, indexSession);
results.add(mh);
}
} else {
// ASync
if (indexString != -1) {
MessageHandler mh = new PojoMessageHandlerPartialText(pojo,
m, session, params, indexString, false,
indexBoolean, indexSession, maxMessageSize);
results.add(mh);
} else if (indexByteArray != -1) {
MessageHandler mh = new PojoMessageHandlerPartialBinary(
pojo, m, session, params, indexByteArray, true,
indexBoolean, indexSession, maxMessageSize);
results.add(mh);
} else {
MessageHandler mh = new PojoMessageHandlerPartialBinary(
pojo, m, session, params, indexByteBuffer, false,
indexBoolean, indexSession, maxMessageSize);
results.add(mh);
}
}
return results;
}
}
private enum MethodType {
ON_OPEN,
ON_CLOSE,
ON_ERROR
}
}

View File

@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.pojo;
/**
* Stores the parameter type and name for a parameter that needs to be passed to
* an onXxx method of {@link javax.websocket.Endpoint}. The name is only present
* for parameters annotated with
* {@link javax.websocket.server.PathParam}. For the
* {@link javax.websocket.Session} and {@link java.lang.Throwable} parameters,
* {@link #getName()} will always return <code>null</code>.
*/
public class PojoPathParam {
private final Class<?> type;
private final String name;
public PojoPathParam(Class<?> type, String name) {
this.type = type;
this.name = name;
}
public Class<?> getType() {
return type;
}
public String getName() {
return name;
}
}

View File

@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* This package provides the necessary plumbing to convert an annotated POJO
* into a WebSocket {@link javax.websocket.Endpoint}.
*/
package nginx.unit.websocket.pojo;

View File

@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
/**
* Internal implementation constants.
*/
public class Constants {
public static final String BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM =
"nginx.unit.websocket.binaryBufferSize";
public static final String TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM =
"nginx.unit.websocket.textBufferSize";
public static final String ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM =
"nginx.unit.websocket.noAddAfterHandshake";
public static final String SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE =
"javax.websocket.server.ServerContainer";
private Constants() {
// Hide default constructor
}
}

View File

@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
public class DefaultServerEndpointConfigurator
extends ServerEndpointConfig.Configurator {
@Override
public <T> T getEndpointInstance(Class<T> clazz)
throws InstantiationException {
try {
return clazz.getConstructor().newInstance();
} catch (InstantiationException e) {
throw e;
} catch (ReflectiveOperationException e) {
InstantiationException ie = new InstantiationException();
ie.initCause(e);
throw ie;
}
}
@Override
public String getNegotiatedSubprotocol(List<String> supported,
List<String> requested) {
for (String request : requested) {
if (supported.contains(request)) {
return request;
}
}
return "";
}
@Override
public List<Extension> getNegotiatedExtensions(List<Extension> installed,
List<Extension> requested) {
Set<String> installedNames = new HashSet<>();
for (Extension e : installed) {
installedNames.add(e.getName());
}
List<Extension> result = new ArrayList<>();
for (Extension request : requested) {
if (installedNames.contains(request.getName())) {
result.add(request);
}
}
return result;
}
@Override
public boolean checkOrigin(String originHeaderValue) {
return true;
}
@Override
public void modifyHandshake(ServerEndpointConfig sec,
HandshakeRequest request, HandshakeResponse response) {
// NO-OP
}
}

View File

@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
serverContainer.addNotAllowed=No further Endpoints may be registered once an attempt has been made to use one of the previously registered endpoints
serverContainer.configuratorFail=Failed to create configurator of type [{0}] for POJO of type [{1}]
serverContainer.duplicatePaths=Multiple Endpoints may not be deployed to the same path [{0}] : existing endpoint was [{1}] and new endpoint is [{2}]
serverContainer.encoderFail=Unable to create encoder of type [{0}]
serverContainer.endpointDeploy=Endpoint class [{0}] deploying to path [{1}] in ServletContext [{2}]
serverContainer.missingAnnotation=Cannot deploy POJO class [{0}] as it is not annotated with @ServerEndpoint
serverContainer.missingEndpoint=An Endpoint instance has been request for path [{0}] but no matching Endpoint class was found
serverContainer.pojoDeploy=POJO class [{0}] deploying to path [{1}] in ServletContext [{2}]
serverContainer.servletContextMismatch=Attempted to register a POJO annotated for WebSocket at path [{0}] in the ServletContext with context path [{1}] when the WebSocket ServerContainer is allocated to the ServletContext with context path [{2}]
serverContainer.servletContextMissing=No ServletContext was specified
upgradeUtil.incompatibleRsv=Extensions were specified that have incompatible RSV bit usage
uriTemplate.duplicateParameter=The parameter [{0}] appears more than once in the path which is not permitted
uriTemplate.emptySegment=The path [{0}] contains one or more empty segments which are is not permitted
uriTemplate.invalidPath=The path [{0}] is not valid.
uriTemplate.invalidSegment=The segment [{0}] is not valid in the provided path [{1}]
wsFrameServer.bytesRead=Read [{0}] bytes into input buffer ready for processing
wsFrameServer.illegalReadState=Unexpected read state [{0}]
wsFrameServer.onDataAvailable=Method entry
wsHttpUpgradeHandler.closeOnError=Closing WebSocket connection due to an error
wsHttpUpgradeHandler.destroyFailed=Failed to close WebConnection while destroying the WebSocket HttpUpgradeHandler
wsHttpUpgradeHandler.noPreInit=The preInit() method must be called to configure the WebSocket HttpUpgradeHandler before the container calls init(). Usually, this means the Servlet that created the WsHttpUpgradeHandler instance should also call preInit()
wsHttpUpgradeHandler.serverStop=The server is stopping
wsRemoteEndpointServer.closeFailed=Failed to close the ServletOutputStream connection cleanly

View File

@@ -0,0 +1,285 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.Endpoint;
import javax.websocket.Extension;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.ServerEndpointConfig;
import nginx.unit.Request;
import org.apache.tomcat.util.codec.binary.Base64;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.util.security.ConcurrentMessageDigest;
import nginx.unit.websocket.Constants;
import nginx.unit.websocket.Transformation;
import nginx.unit.websocket.TransformationFactory;
import nginx.unit.websocket.Util;
import nginx.unit.websocket.WsHandshakeResponse;
import nginx.unit.websocket.pojo.PojoEndpointServer;
public class UpgradeUtil {
private static final StringManager sm =
StringManager.getManager(UpgradeUtil.class.getPackage().getName());
private static final byte[] WS_ACCEPT =
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
StandardCharsets.ISO_8859_1);
private UpgradeUtil() {
// Utility class. Hide default constructor.
}
/**
* Checks to see if this is an HTTP request that includes a valid upgrade
* request to web socket.
* <p>
* Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
* WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
* 6455 section 4.1 requires that a WebSocket Upgrade uses GET.
* @param request The request to check if it is an HTTP upgrade request for
* a WebSocket connection
* @param response The response associated with the request
* @return <code>true</code> if the request includes a HTTP Upgrade request
* for the WebSocket protocol, otherwise <code>false</code>
*/
public static boolean isWebSocketUpgradeRequest(ServletRequest request,
ServletResponse response) {
Request r = (Request) request.getAttribute(Request.BARE);
return ((request instanceof HttpServletRequest) &&
(response instanceof HttpServletResponse) &&
(r != null) &&
(r.isUpgrade()));
}
public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
HttpServletResponse resp, ServerEndpointConfig sec,
Map<String,String> pathParams)
throws ServletException, IOException {
// Origin check
String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
if (!sec.getConfigurator().checkOrigin(origin)) {
resp.sendError(HttpServletResponse.SC_FORBIDDEN);
return;
}
// Sub-protocols
List<String> subProtocols = getTokensFromHeader(req,
Constants.WS_PROTOCOL_HEADER_NAME);
String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
sec.getSubprotocols(), subProtocols);
// Extensions
// Should normally only be one header but handle the case of multiple
// headers
List<Extension> extensionsRequested = new ArrayList<>();
Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
while (extHeaders.hasMoreElements()) {
Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
}
// Negotiation phase 1. By default this simply filters out the
// extensions that the server does not support but applications could
// use a custom configurator to do more than this.
List<Extension> installedExtensions = null;
if (sec.getExtensions().size() == 0) {
installedExtensions = Constants.INSTALLED_EXTENSIONS;
} else {
installedExtensions = new ArrayList<>();
installedExtensions.addAll(sec.getExtensions());
installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
}
List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
installedExtensions, extensionsRequested);
// Negotiation phase 2. Create the Transformations that will be applied
// to this connection. Note than an extension may be dropped at this
// point if the client has requested a configuration that the server is
// unable to support.
List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);
List<Extension> negotiatedExtensionsPhase2;
if (transformations.isEmpty()) {
negotiatedExtensionsPhase2 = Collections.emptyList();
} else {
negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size());
for (Transformation t : transformations) {
negotiatedExtensionsPhase2.add(t.getExtensionResponse());
}
}
WsHttpUpgradeHandler wsHandler =
req.upgrade(WsHttpUpgradeHandler.class);
WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
WsHandshakeResponse wsResponse = new WsHandshakeResponse();
WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
new WsPerSessionServerEndpointConfig(sec);
sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
wsRequest, wsResponse);
//wsRequest.finished();
// Add any additional headers
for (Entry<String,List<String>> entry :
wsResponse.getHeaders().entrySet()) {
for (String headerValue: entry.getValue()) {
resp.addHeader(entry.getKey(), headerValue);
}
}
Endpoint ep;
try {
Class<?> clazz = sec.getEndpointClass();
if (Endpoint.class.isAssignableFrom(clazz)) {
ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
clazz);
} else {
ep = new PojoEndpointServer();
// Need to make path params available to POJO
perSessionServerEndpointConfig.getUserProperties().put(
nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams);
}
} catch (InstantiationException e) {
throw new ServletException(e);
}
wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
negotiatedExtensionsPhase2, subProtocol, null, pathParams,
req.isSecure());
wsHandler.init(null);
}
private static List<Transformation> createTransformations(
List<Extension> negotiatedExtensions) {
TransformationFactory factory = TransformationFactory.getInstance();
LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences =
new LinkedHashMap<>();
// Result will likely be smaller than this
List<Transformation> result = new ArrayList<>(negotiatedExtensions.size());
for (Extension extension : negotiatedExtensions) {
List<List<Extension.Parameter>> preferences =
extensionPreferences.get(extension.getName());
if (preferences == null) {
preferences = new ArrayList<>();
extensionPreferences.put(extension.getName(), preferences);
}
preferences.add(extension.getParameters());
}
for (Map.Entry<String,List<List<Extension.Parameter>>> entry :
extensionPreferences.entrySet()) {
Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
if (transformation != null) {
result.add(transformation);
}
}
return result;
}
private static void append(StringBuilder sb, Extension extension) {
if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
return;
}
sb.append(extension.getName());
for (Extension.Parameter p : extension.getParameters()) {
sb.append(';');
sb.append(p.getName());
if (p.getValue() != null) {
sb.append('=');
sb.append(p.getValue());
}
}
}
/*
* This only works for tokens. Quoted strings need more sophisticated
* parsing.
*/
private static boolean headerContainsToken(HttpServletRequest req,
String headerName, String target) {
Enumeration<String> headers = req.getHeaders(headerName);
while (headers.hasMoreElements()) {
String header = headers.nextElement();
String[] tokens = header.split(",");
for (String token : tokens) {
if (target.equalsIgnoreCase(token.trim())) {
return true;
}
}
}
return false;
}
/*
* This only works for tokens. Quoted strings need more sophisticated
* parsing.
*/
private static List<String> getTokensFromHeader(HttpServletRequest req,
String headerName) {
List<String> result = new ArrayList<>();
Enumeration<String> headers = req.getHeaders(headerName);
while (headers.hasMoreElements()) {
String header = headers.nextElement();
String[] tokens = header.split(",");
for (String token : tokens) {
result.add(token.trim());
}
}
return result;
}
private static String getWebSocketAccept(String key) {
byte[] digest = ConcurrentMessageDigest.digestSHA1(
key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
return Base64.encodeBase64String(digest);
}
}

View File

@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.websocket.DeploymentException;
import org.apache.tomcat.util.res.StringManager;
/**
* Extracts path parameters from URIs used to create web socket connections
* using the URI template defined for the associated Endpoint.
*/
public class UriTemplate {
private static final StringManager sm = StringManager.getManager(UriTemplate.class);
private final String normalized;
private final List<Segment> segments = new ArrayList<>();
private final boolean hasParameters;
public UriTemplate(String path) throws DeploymentException {
if (path == null || path.length() ==0 || !path.startsWith("/")) {
throw new DeploymentException(
sm.getString("uriTemplate.invalidPath", path));
}
StringBuilder normalized = new StringBuilder(path.length());
Set<String> paramNames = new HashSet<>();
// Include empty segments.
String[] segments = path.split("/", -1);
int paramCount = 0;
int segmentCount = 0;
for (int i = 0; i < segments.length; i++) {
String segment = segments[i];
if (segment.length() == 0) {
if (i == 0 || (i == segments.length - 1 && paramCount == 0)) {
// Ignore the first empty segment as the path must always
// start with '/'
// Ending with a '/' is also OK for instances used for
// matches but not for parameterised templates.
continue;
} else {
// As per EG discussion, all other empty segments are
// invalid
throw new IllegalArgumentException(sm.getString(
"uriTemplate.emptySegment", path));
}
}
normalized.append('/');
int index = -1;
if (segment.startsWith("{") && segment.endsWith("}")) {
index = segmentCount;
segment = segment.substring(1, segment.length() - 1);
normalized.append('{');
normalized.append(paramCount++);
normalized.append('}');
if (!paramNames.add(segment)) {
throw new IllegalArgumentException(sm.getString(
"uriTemplate.duplicateParameter", segment));
}
} else {
if (segment.contains("{") || segment.contains("}")) {
throw new IllegalArgumentException(sm.getString(
"uriTemplate.invalidSegment", segment, path));
}
normalized.append(segment);
}
this.segments.add(new Segment(index, segment));
segmentCount++;
}
this.normalized = normalized.toString();
this.hasParameters = paramCount > 0;
}
public Map<String,String> match(UriTemplate candidate) {
Map<String,String> result = new HashMap<>();
// Should not happen but for safety
if (candidate.getSegmentCount() != getSegmentCount()) {
return null;
}
Iterator<Segment> candidateSegments =
candidate.getSegments().iterator();
Iterator<Segment> targetSegments = segments.iterator();
while (candidateSegments.hasNext()) {
Segment candidateSegment = candidateSegments.next();
Segment targetSegment = targetSegments.next();
if (targetSegment.getParameterIndex() == -1) {
// Not a parameter - values must match
if (!targetSegment.getValue().equals(
candidateSegment.getValue())) {
// Not a match. Stop here
return null;
}
} else {
// Parameter
result.put(targetSegment.getValue(),
candidateSegment.getValue());
}
}
return result;
}
public boolean hasParameters() {
return hasParameters;
}
public int getSegmentCount() {
return segments.size();
}
public String getNormalizedPath() {
return normalized;
}
private List<Segment> getSegments() {
return segments;
}
private static class Segment {
private final int parameterIndex;
private final String value;
public Segment(int parameterIndex, String value) {
this.parameterIndex = parameterIndex;
this.value = value;
}
public int getParameterIndex() {
return parameterIndex;
}
public String getValue() {
return value;
}
}
}

View File

@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import javax.servlet.ServletContext;
import javax.servlet.ServletContextEvent;
import javax.servlet.ServletContextListener;
/**
* In normal usage, this {@link ServletContextListener} does not need to be
* explicitly configured as the {@link WsSci} performs all the necessary
* bootstrap and installs this listener in the {@link ServletContext}. If the
* {@link WsSci} is disabled, this listener must be added manually to every
* {@link ServletContext} that uses WebSocket to bootstrap the
* {@link WsServerContainer} correctly.
*/
public class WsContextListener implements ServletContextListener {
@Override
public void contextInitialized(ServletContextEvent sce) {
ServletContext sc = sce.getServletContext();
// Don't trigger WebSocket initialization if a WebSocket Server
// Container is already present
if (sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE) == null) {
WsSci.init(sce.getServletContext(), false);
}
}
@Override
public void contextDestroyed(ServletContextEvent sce) {
ServletContext sc = sce.getServletContext();
Object obj = sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
if (obj instanceof WsServerContainer) {
((WsServerContainer) obj).destroy();
}
}
}

View File

@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.io.IOException;
import javax.servlet.FilterChain;
import javax.servlet.GenericFilter;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Handles the initial HTTP connection for WebSocket connections.
*/
public class WsFilter extends GenericFilter {
private static final long serialVersionUID = 1L;
private transient WsServerContainer sc;
@Override
public void init() throws ServletException {
sc = (WsServerContainer) getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
}
@Override
public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
// This filter only needs to handle WebSocket upgrade requests
if (!sc.areEndpointsRegistered() ||
!UpgradeUtil.isWebSocketUpgradeRequest(request, response)) {
chain.doFilter(request, response);
return;
}
// HTTP request with an upgrade header for WebSocket present
HttpServletRequest req = (HttpServletRequest) request;
HttpServletResponse resp = (HttpServletResponse) response;
// Check to see if this WebSocket implementation has a matching mapping
String path;
String pathInfo = req.getPathInfo();
if (pathInfo == null) {
path = req.getServletPath();
} else {
path = req.getServletPath() + pathInfo;
}
WsMappingResult mappingResult = sc.findMapping(path);
if (mappingResult == null) {
// No endpoint registered for the requested path. Let the
// application handle it (it might redirect or forward for example)
chain.doFilter(request, response);
return;
}
UpgradeUtil.doUpgrade(sc, req, resp, mappingResult.getConfig(),
mappingResult.getPathParams());
}
}

View File

@@ -0,0 +1,196 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.http.HttpServletRequest;
import javax.websocket.server.HandshakeRequest;
import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
import org.apache.tomcat.util.res.StringManager;
/**
* Represents the request that this session was opened under.
*/
public class WsHandshakeRequest implements HandshakeRequest {
private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class);
private final URI requestUri;
private final Map<String,List<String>> parameterMap;
private final String queryString;
private final Principal userPrincipal;
private final Map<String,List<String>> headers;
private final Object httpSession;
private volatile HttpServletRequest request;
public WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams) {
this.request = request;
queryString = request.getQueryString();
userPrincipal = request.getUserPrincipal();
httpSession = request.getSession(false);
requestUri = buildRequestUri(request);
// ParameterMap
Map<String,String[]> originalParameters = request.getParameterMap();
Map<String,List<String>> newParameters =
new HashMap<>(originalParameters.size());
for (Entry<String,String[]> entry : originalParameters.entrySet()) {
newParameters.put(entry.getKey(),
Collections.unmodifiableList(
Arrays.asList(entry.getValue())));
}
for (Entry<String,String> entry : pathParams.entrySet()) {
newParameters.put(entry.getKey(),
Collections.unmodifiableList(
Collections.singletonList(entry.getValue())));
}
parameterMap = Collections.unmodifiableMap(newParameters);
// Headers
Map<String,List<String>> newHeaders = new CaseInsensitiveKeyMap<>();
Enumeration<String> headerNames = request.getHeaderNames();
while (headerNames.hasMoreElements()) {
String headerName = headerNames.nextElement();
newHeaders.put(headerName, Collections.unmodifiableList(
Collections.list(request.getHeaders(headerName))));
}
headers = Collections.unmodifiableMap(newHeaders);
}
@Override
public URI getRequestURI() {
return requestUri;
}
@Override
public Map<String,List<String>> getParameterMap() {
return parameterMap;
}
@Override
public String getQueryString() {
return queryString;
}
@Override
public Principal getUserPrincipal() {
return userPrincipal;
}
@Override
public Map<String,List<String>> getHeaders() {
return headers;
}
@Override
public boolean isUserInRole(String role) {
if (request == null) {
throw new IllegalStateException();
}
return request.isUserInRole(role);
}
@Override
public Object getHttpSession() {
return httpSession;
}
/**
* Called when the HandshakeRequest is no longer required. Since an instance
* of this class retains a reference to the current HttpServletRequest that
* reference needs to be cleared as the HttpServletRequest may be reused.
*
* There is no reason for instances of this class to be accessed once the
* handshake has been completed.
*/
void finished() {
request = null;
}
/*
* See RequestUtil.getRequestURL()
*/
private static URI buildRequestUri(HttpServletRequest req) {
StringBuffer uri = new StringBuffer();
String scheme = req.getScheme();
int port = req.getServerPort();
if (port < 0) {
// Work around java.net.URL bug
port = 80;
}
if ("http".equals(scheme)) {
uri.append("ws");
} else if ("https".equals(scheme)) {
uri.append("wss");
} else {
// Should never happen
throw new IllegalArgumentException(
sm.getString("wsHandshakeRequest.unknownScheme", scheme));
}
uri.append("://");
uri.append(req.getServerName());
if ((scheme.equals("http") && (port != 80))
|| (scheme.equals("https") && (port != 443))) {
uri.append(':');
uri.append(port);
}
uri.append(req.getRequestURI());
if (req.getQueryString() != null) {
uri.append("?");
uri.append(req.getQueryString());
}
try {
return new URI(uri.toString());
} catch (URISyntaxException e) {
// Should never happen
throw new IllegalArgumentException(
sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e);
}
}
public Object getAttribute(String name)
{
return request != null ? request.getAttribute(name) : null;
}
}

View File

@@ -0,0 +1,172 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpUpgradeHandler;
import javax.servlet.http.WebConnection;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.Transformation;
import nginx.unit.websocket.WsIOException;
import nginx.unit.websocket.WsSession;
import nginx.unit.Request;
/**
* Servlet 3.1 HTTP upgrade handler for WebSocket connections.
*/
public class WsHttpUpgradeHandler implements HttpUpgradeHandler {
private final Log log = LogFactory.getLog(WsHttpUpgradeHandler.class); // must not be static
private static final StringManager sm = StringManager.getManager(WsHttpUpgradeHandler.class);
private final ClassLoader applicationClassLoader;
private Endpoint ep;
private EndpointConfig endpointConfig;
private WsServerContainer webSocketContainer;
private WsHandshakeRequest handshakeRequest;
private List<Extension> negotiatedExtensions;
private String subProtocol;
private Transformation transformation;
private Map<String,String> pathParameters;
private boolean secure;
private WebConnection connection;
private WsRemoteEndpointImplServer wsRemoteEndpointServer;
private WsSession wsSession;
public WsHttpUpgradeHandler() {
applicationClassLoader = Thread.currentThread().getContextClassLoader();
}
public void preInit(Endpoint ep, EndpointConfig endpointConfig,
WsServerContainer wsc, WsHandshakeRequest handshakeRequest,
List<Extension> negotiatedExtensionsPhase2, String subProtocol,
Transformation transformation, Map<String,String> pathParameters,
boolean secure) {
this.ep = ep;
this.endpointConfig = endpointConfig;
this.webSocketContainer = wsc;
this.handshakeRequest = handshakeRequest;
this.negotiatedExtensions = negotiatedExtensionsPhase2;
this.subProtocol = subProtocol;
this.transformation = transformation;
this.pathParameters = pathParameters;
this.secure = secure;
}
@Override
public void init(WebConnection connection) {
if (ep == null) {
throw new IllegalStateException(
sm.getString("wsHttpUpgradeHandler.noPreInit"));
}
String httpSessionId = null;
Object session = handshakeRequest.getHttpSession();
if (session != null ) {
httpSessionId = ((HttpSession) session).getId();
}
nginx.unit.Context.trace("UpgradeHandler.init(" + connection + ")");
/*
// Need to call onOpen using the web application's class loader
// Create the frame using the application's class loader so it can pick
// up application specific config from the ServerContainerImpl
Thread t = Thread.currentThread();
ClassLoader cl = t.getContextClassLoader();
t.setContextClassLoader(applicationClassLoader);
*/
try {
Request r = (Request) handshakeRequest.getAttribute(Request.BARE);
wsRemoteEndpointServer = new WsRemoteEndpointImplServer(webSocketContainer);
wsSession = new WsSession(ep, wsRemoteEndpointServer,
webSocketContainer, handshakeRequest.getRequestURI(),
handshakeRequest.getParameterMap(),
handshakeRequest.getQueryString(),
handshakeRequest.getUserPrincipal(), httpSessionId,
negotiatedExtensions, subProtocol, pathParameters, secure,
endpointConfig, r);
ep.onOpen(wsSession, endpointConfig);
webSocketContainer.registerSession(ep, wsSession);
} catch (DeploymentException e) {
throw new IllegalArgumentException(e);
/*
} finally {
t.setContextClassLoader(cl);
*/
}
}
@Override
public void destroy() {
if (connection != null) {
try {
connection.close();
} catch (Exception e) {
log.error(sm.getString("wsHttpUpgradeHandler.destroyFailed"), e);
}
}
}
private void onError(Throwable throwable) {
// Need to call onError using the web application's class loader
Thread t = Thread.currentThread();
ClassLoader cl = t.getContextClassLoader();
t.setContextClassLoader(applicationClassLoader);
try {
ep.onError(wsSession, throwable);
} finally {
t.setContextClassLoader(cl);
}
}
private void close(CloseReason cr) {
/*
* Any call to this method is a result of a problem reading from the
* client. At this point that state of the connection is unknown.
* Attempt to send a close frame to the client and then close the socket
* immediately. There is no point in waiting for a close frame from the
* client because there is no guarantee that we can recover from
* whatever messed up state the client put the connection into.
*/
wsSession.onClose(cr);
}
}

View File

@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.util.Map;
import javax.websocket.server.ServerEndpointConfig;
class WsMappingResult {
private final ServerEndpointConfig config;
private final Map<String,String> pathParams;
WsMappingResult(ServerEndpointConfig config,
Map<String,String> pathParams) {
this.config = config;
this.pathParams = pathParams;
}
ServerEndpointConfig getConfig() {
return config;
}
Map<String,String> getPathParams() {
return pathParams;
}
}

View File

@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.websocket.Decoder;
import javax.websocket.Encoder;
import javax.websocket.Extension;
import javax.websocket.server.ServerEndpointConfig;
/**
* Wraps the provided {@link ServerEndpointConfig} and provides a per session
* view - the difference being that the map returned by {@link
* #getUserProperties()} is unique to this instance rather than shared with the
* wrapped {@link ServerEndpointConfig}.
*/
class WsPerSessionServerEndpointConfig implements ServerEndpointConfig {
private final ServerEndpointConfig perEndpointConfig;
private final Map<String,Object> perSessionUserProperties =
new ConcurrentHashMap<>();
WsPerSessionServerEndpointConfig(ServerEndpointConfig perEndpointConfig) {
this.perEndpointConfig = perEndpointConfig;
perSessionUserProperties.putAll(perEndpointConfig.getUserProperties());
}
@Override
public List<Class<? extends Encoder>> getEncoders() {
return perEndpointConfig.getEncoders();
}
@Override
public List<Class<? extends Decoder>> getDecoders() {
return perEndpointConfig.getDecoders();
}
@Override
public Map<String,Object> getUserProperties() {
return perSessionUserProperties;
}
@Override
public Class<?> getEndpointClass() {
return perEndpointConfig.getEndpointClass();
}
@Override
public String getPath() {
return perEndpointConfig.getPath();
}
@Override
public List<String> getSubprotocols() {
return perEndpointConfig.getSubprotocols();
}
@Override
public List<Extension> getExtensions() {
return perEndpointConfig.getExtensions();
}
@Override
public Configurator getConfigurator() {
return perEndpointConfig.getConfigurator();
}
}

View File

@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.io.EOFException;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.nio.channels.InterruptedByTimeoutException;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import javax.websocket.SendHandler;
import javax.websocket.SendResult;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.Transformation;
import nginx.unit.websocket.WsRemoteEndpointImplBase;
/**
* This is the server side {@link javax.websocket.RemoteEndpoint} implementation
* - i.e. what the server uses to send data to the client.
*/
public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
private static final StringManager sm =
StringManager.getManager(WsRemoteEndpointImplServer.class);
private final Log log = LogFactory.getLog(WsRemoteEndpointImplServer.class); // must not be static
private volatile SendHandler handler = null;
private volatile ByteBuffer[] buffers = null;
private volatile long timeoutExpiry = -1;
private volatile boolean close;
public WsRemoteEndpointImplServer(
WsServerContainer serverContainer) {
}
@Override
protected final boolean isMasked() {
return false;
}
@Override
protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
ByteBuffer... buffers) {
}
@Override
protected void doClose() {
if (handler != null) {
// close() can be triggered by a wide range of scenarios. It is far
// simpler just to always use a dispatch than it is to try and track
// whether or not this method was called by the same thread that
// triggered the write
clearHandler(new EOFException(), true);
}
}
protected long getTimeoutExpiry() {
return timeoutExpiry;
}
/*
* Currently this is only called from the background thread so we could just
* call clearHandler() with useDispatch == false but the method parameter
* was added in case other callers started to use this method to make sure
* that those callers think through what the correct value of useDispatch is
* for them.
*/
protected void onTimeout(boolean useDispatch) {
if (handler != null) {
clearHandler(new SocketTimeoutException(), useDispatch);
}
close();
}
@Override
protected void setTransformation(Transformation transformation) {
// Overridden purely so it is visible to other classes in this package
super.setTransformation(transformation);
}
/**
*
* @param t The throwable associated with any error that
* occurred
* @param useDispatch Should {@link SendHandler#onResult(SendResult)} be
* called from a new thread, keeping in mind the
* requirements of
* {@link javax.websocket.RemoteEndpoint.Async}
*/
private void clearHandler(Throwable t, boolean useDispatch) {
// Setting the result marks this (partial) message as
// complete which means the next one may be sent which
// could update the value of the handler. Therefore, keep a
// local copy before signalling the end of the (partial)
// message.
SendHandler sh = handler;
handler = null;
buffers = null;
if (sh != null) {
if (useDispatch) {
OnResultRunnable r = new OnResultRunnable(sh, t);
} else {
if (t == null) {
sh.onResult(new SendResult());
} else {
sh.onResult(new SendResult(t));
}
}
}
}
private static class OnResultRunnable implements Runnable {
private final SendHandler sh;
private final Throwable t;
private OnResultRunnable(SendHandler sh, Throwable t) {
this.sh = sh;
this.t = t;
}
@Override
public void run() {
if (t == null) {
sh.onResult(new SendResult());
} else {
sh.onResult(new SendResult(t));
}
}
}
}

View File

@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.lang.reflect.Modifier;
import java.util.HashSet;
import java.util.Set;
import javax.servlet.ServletContainerInitializer;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.annotation.HandlesTypes;
import javax.websocket.ContainerProvider;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerApplicationConfig;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
/**
* Registers an interest in any class that is annotated with
* {@link ServerEndpoint} so that Endpoint can be published via the WebSocket
* server.
*/
@HandlesTypes({ServerEndpoint.class, ServerApplicationConfig.class,
Endpoint.class})
public class WsSci implements ServletContainerInitializer {
@Override
public void onStartup(Set<Class<?>> clazzes, ServletContext ctx)
throws ServletException {
WsServerContainer sc = init(ctx, true);
if (clazzes == null || clazzes.size() == 0) {
return;
}
// Group the discovered classes by type
Set<ServerApplicationConfig> serverApplicationConfigs = new HashSet<>();
Set<Class<? extends Endpoint>> scannedEndpointClazzes = new HashSet<>();
Set<Class<?>> scannedPojoEndpoints = new HashSet<>();
try {
// wsPackage is "javax.websocket."
String wsPackage = ContainerProvider.class.getName();
wsPackage = wsPackage.substring(0, wsPackage.lastIndexOf('.') + 1);
for (Class<?> clazz : clazzes) {
int modifiers = clazz.getModifiers();
if (!Modifier.isPublic(modifiers) ||
Modifier.isAbstract(modifiers)) {
// Non-public or abstract - skip it.
continue;
}
// Protect against scanning the WebSocket API JARs
if (clazz.getName().startsWith(wsPackage)) {
continue;
}
if (ServerApplicationConfig.class.isAssignableFrom(clazz)) {
serverApplicationConfigs.add(
(ServerApplicationConfig) clazz.getConstructor().newInstance());
}
if (Endpoint.class.isAssignableFrom(clazz)) {
@SuppressWarnings("unchecked")
Class<? extends Endpoint> endpoint =
(Class<? extends Endpoint>) clazz;
scannedEndpointClazzes.add(endpoint);
}
if (clazz.isAnnotationPresent(ServerEndpoint.class)) {
scannedPojoEndpoints.add(clazz);
}
}
} catch (ReflectiveOperationException e) {
throw new ServletException(e);
}
// Filter the results
Set<ServerEndpointConfig> filteredEndpointConfigs = new HashSet<>();
Set<Class<?>> filteredPojoEndpoints = new HashSet<>();
if (serverApplicationConfigs.isEmpty()) {
filteredPojoEndpoints.addAll(scannedPojoEndpoints);
} else {
for (ServerApplicationConfig config : serverApplicationConfigs) {
Set<ServerEndpointConfig> configFilteredEndpoints =
config.getEndpointConfigs(scannedEndpointClazzes);
if (configFilteredEndpoints != null) {
filteredEndpointConfigs.addAll(configFilteredEndpoints);
}
Set<Class<?>> configFilteredPojos =
config.getAnnotatedEndpointClasses(
scannedPojoEndpoints);
if (configFilteredPojos != null) {
filteredPojoEndpoints.addAll(configFilteredPojos);
}
}
}
try {
// Deploy endpoints
for (ServerEndpointConfig config : filteredEndpointConfigs) {
sc.addEndpoint(config);
}
// Deploy POJOs
for (Class<?> clazz : filteredPojoEndpoints) {
sc.addEndpoint(clazz);
}
} catch (DeploymentException e) {
throw new ServletException(e);
}
}
static WsServerContainer init(ServletContext servletContext,
boolean initBySciMechanism) {
WsServerContainer sc = new WsServerContainer(servletContext);
servletContext.setAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, sc);
servletContext.addListener(new WsSessionListener(sc));
// Can't register the ContextListener again if the ContextListener is
// calling this method
if (initBySciMechanism) {
servletContext.addListener(new WsContextListener());
}
return sc;
}
}

View File

@@ -0,0 +1,470 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.DeploymentException;
import javax.websocket.Encoder;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import javax.websocket.server.ServerEndpointConfig.Configurator;
import org.apache.tomcat.InstanceManager;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.WsSession;
import nginx.unit.websocket.WsWebSocketContainer;
import nginx.unit.websocket.pojo.PojoMethodMapping;
/**
* Provides a per class loader (i.e. per web application) instance of a
* ServerContainer. Web application wide defaults may be configured by setting
* the following servlet context initialisation parameters to the desired
* values.
* <ul>
* <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li>
* <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li>
* </ul>
*/
public class WsServerContainer extends WsWebSocketContainer
implements ServerContainer {
private static final StringManager sm = StringManager.getManager(WsServerContainer.class);
private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED =
new CloseReason(CloseCodes.VIOLATED_POLICY,
"This connection was established under an authenticated " +
"HTTP session that has ended.");
private final ServletContext servletContext;
private final Map<String,ServerEndpointConfig> configExactMatchMap =
new ConcurrentHashMap<>();
private final Map<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap =
new ConcurrentHashMap<>();
private volatile boolean enforceNoAddAfterHandshake =
nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE;
private volatile boolean addAllowed = true;
private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>();
private volatile boolean endpointsRegistered = false;
WsServerContainer(ServletContext servletContext) {
this.servletContext = servletContext;
setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName()));
// Configure servlet context wide defaults
String value = servletContext.getInitParameter(
Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
if (value != null) {
setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value));
}
value = servletContext.getInitParameter(
Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
if (value != null) {
setDefaultMaxTextMessageBufferSize(Integer.parseInt(value));
}
value = servletContext.getInitParameter(
Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM);
if (value != null) {
setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value));
}
FilterRegistration.Dynamic fr = servletContext.addFilter(
"Tomcat WebSocket (JSR356) Filter", new WsFilter());
fr.setAsyncSupported(true);
EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST,
DispatcherType.FORWARD);
fr.addMappingForUrlPatterns(types, true, "/*");
}
/**
* Published the provided endpoint implementation at the specified path with
* the specified configuration. {@link #WsServerContainer(ServletContext)}
* must be called before calling this method.
*
* @param sec The configuration to use when creating endpoint instances
* @throws DeploymentException if the endpoint cannot be published as
* requested
*/
@Override
public void addEndpoint(ServerEndpointConfig sec)
throws DeploymentException {
if (enforceNoAddAfterHandshake && !addAllowed) {
throw new DeploymentException(
sm.getString("serverContainer.addNotAllowed"));
}
if (servletContext == null) {
throw new DeploymentException(
sm.getString("serverContainer.servletContextMissing"));
}
String path = sec.getPath();
// Add method mapping to user properties
PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(),
sec.getDecoders(), path);
if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null
|| methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) {
sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY,
methodMapping);
}
UriTemplate uriTemplate = new UriTemplate(path);
if (uriTemplate.hasParameters()) {
Integer key = Integer.valueOf(uriTemplate.getSegmentCount());
SortedSet<TemplatePathMatch> templateMatches =
configTemplateMatchMap.get(key);
if (templateMatches == null) {
// Ensure that if concurrent threads execute this block they
// both end up using the same TreeSet instance
templateMatches = new TreeSet<>(
TemplatePathMatchComparator.getInstance());
configTemplateMatchMap.putIfAbsent(key, templateMatches);
templateMatches = configTemplateMatchMap.get(key);
}
if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) {
// Duplicate uriTemplate;
throw new DeploymentException(
sm.getString("serverContainer.duplicatePaths", path,
sec.getEndpointClass(),
sec.getEndpointClass()));
}
} else {
// Exact match
ServerEndpointConfig old = configExactMatchMap.put(path, sec);
if (old != null) {
// Duplicate path mappings
throw new DeploymentException(
sm.getString("serverContainer.duplicatePaths", path,
old.getEndpointClass(),
sec.getEndpointClass()));
}
}
endpointsRegistered = true;
}
/**
* Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)}
* for publishing plain old java objects (POJOs) that have been annotated as
* WebSocket endpoints.
*
* @param pojo The annotated POJO
*/
@Override
public void addEndpoint(Class<?> pojo) throws DeploymentException {
ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class);
if (annotation == null) {
throw new DeploymentException(
sm.getString("serverContainer.missingAnnotation",
pojo.getName()));
}
String path = annotation.value();
// Validate encoders
validateEncoders(annotation.encoders());
// ServerEndpointConfig
ServerEndpointConfig sec;
Class<? extends Configurator> configuratorClazz =
annotation.configurator();
Configurator configurator = null;
if (!configuratorClazz.equals(Configurator.class)) {
try {
configurator = annotation.configurator().getConstructor().newInstance();
} catch (ReflectiveOperationException e) {
throw new DeploymentException(sm.getString(
"serverContainer.configuratorFail",
annotation.configurator().getName(),
pojo.getClass().getName()), e);
}
}
if (configurator == null) {
configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator();
}
sec = ServerEndpointConfig.Builder.create(pojo, path).
decoders(Arrays.asList(annotation.decoders())).
encoders(Arrays.asList(annotation.encoders())).
subprotocols(Arrays.asList(annotation.subprotocols())).
configurator(configurator).
build();
addEndpoint(sec);
}
boolean areEndpointsRegistered() {
return endpointsRegistered;
}
/**
* Until the WebSocket specification provides such a mechanism, this Tomcat
* proprietary method is provided to enable applications to programmatically
* determine whether or not to upgrade an individual request to WebSocket.
* <p>
* Note: This method is not used by Tomcat but is used directly by
* third-party code and must not be removed.
*
* @param request The request object to be upgraded
* @param response The response object to be populated with the result of
* the upgrade
* @param sec The server endpoint to use to process the upgrade request
* @param pathParams The path parameters associated with the upgrade request
*
* @throws ServletException If a configuration error prevents the upgrade
* from taking place
* @throws IOException If an I/O error occurs during the upgrade process
*/
public void doUpgrade(HttpServletRequest request,
HttpServletResponse response, ServerEndpointConfig sec,
Map<String,String> pathParams)
throws ServletException, IOException {
UpgradeUtil.doUpgrade(this, request, response, sec, pathParams);
}
public WsMappingResult findMapping(String path) {
// Prevent registering additional endpoints once the first attempt has
// been made to use one
if (addAllowed) {
addAllowed = false;
}
// Check an exact match. Simple case as there are no templates.
ServerEndpointConfig sec = configExactMatchMap.get(path);
if (sec != null) {
return new WsMappingResult(sec, Collections.<String, String>emptyMap());
}
// No exact match. Need to look for template matches.
UriTemplate pathUriTemplate = null;
try {
pathUriTemplate = new UriTemplate(path);
} catch (DeploymentException e) {
// Path is not valid so can't be matched to a WebSocketEndpoint
return null;
}
// Number of segments has to match
Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount());
SortedSet<TemplatePathMatch> templateMatches =
configTemplateMatchMap.get(key);
if (templateMatches == null) {
// No templates with an equal number of segments so there will be
// no matches
return null;
}
// List is in alphabetical order of normalised templates.
// Correct match is the first one that matches.
Map<String,String> pathParams = null;
for (TemplatePathMatch templateMatch : templateMatches) {
pathParams = templateMatch.getUriTemplate().match(pathUriTemplate);
if (pathParams != null) {
sec = templateMatch.getConfig();
break;
}
}
if (sec == null) {
// No match
return null;
}
return new WsMappingResult(sec, pathParams);
}
public boolean isEnforceNoAddAfterHandshake() {
return enforceNoAddAfterHandshake;
}
public void setEnforceNoAddAfterHandshake(
boolean enforceNoAddAfterHandshake) {
this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake;
}
/**
* {@inheritDoc}
*
* Overridden to make it visible to other classes in this package.
*/
@Override
protected void registerSession(Endpoint endpoint, WsSession wsSession) {
super.registerSession(endpoint, wsSession);
if (wsSession.isOpen() &&
wsSession.getUserPrincipal() != null &&
wsSession.getHttpSessionId() != null) {
registerAuthenticatedSession(wsSession,
wsSession.getHttpSessionId());
}
}
/**
* {@inheritDoc}
*
* Overridden to make it visible to other classes in this package.
*/
@Override
protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
if (wsSession.getUserPrincipal() != null &&
wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession,
wsSession.getHttpSessionId());
}
super.unregisterSession(endpoint, wsSession);
}
private void registerAuthenticatedSession(WsSession wsSession,
String httpSessionId) {
Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
if (wsSessions == null) {
wsSessions = Collections.newSetFromMap(
new ConcurrentHashMap<WsSession,Boolean>());
authenticatedSessions.putIfAbsent(httpSessionId, wsSessions);
wsSessions = authenticatedSessions.get(httpSessionId);
}
wsSessions.add(wsSession);
}
private void unregisterAuthenticatedSession(WsSession wsSession,
String httpSessionId) {
Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
// wsSessions will be null if the HTTP session has ended
if (wsSessions != null) {
wsSessions.remove(wsSession);
}
}
public void closeAuthenticatedSession(String httpSessionId) {
Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId);
if (wsSessions != null && !wsSessions.isEmpty()) {
for (WsSession wsSession : wsSessions) {
try {
wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED);
} catch (IOException e) {
// Any IOExceptions during close will have been caught and the
// onError method called.
}
}
}
}
private static void validateEncoders(Class<? extends Encoder>[] encoders)
throws DeploymentException {
for (Class<? extends Encoder> encoder : encoders) {
// Need to instantiate decoder to ensure it is valid and that
// deployment can be failed if it is not
@SuppressWarnings("unused")
Encoder instance;
try {
encoder.getConstructor().newInstance();
} catch(ReflectiveOperationException e) {
throw new DeploymentException(sm.getString(
"serverContainer.encoderFail", encoder.getName()), e);
}
}
}
private static class TemplatePathMatch {
private final ServerEndpointConfig config;
private final UriTemplate uriTemplate;
public TemplatePathMatch(ServerEndpointConfig config,
UriTemplate uriTemplate) {
this.config = config;
this.uriTemplate = uriTemplate;
}
public ServerEndpointConfig getConfig() {
return config;
}
public UriTemplate getUriTemplate() {
return uriTemplate;
}
}
/**
* This Comparator implementation is thread-safe so only create a single
* instance.
*/
private static class TemplatePathMatchComparator
implements Comparator<TemplatePathMatch> {
private static final TemplatePathMatchComparator INSTANCE =
new TemplatePathMatchComparator();
public static TemplatePathMatchComparator getInstance() {
return INSTANCE;
}
private TemplatePathMatchComparator() {
// Hide default constructor
}
@Override
public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) {
return tpm1.getUriTemplate().getNormalizedPath().compareTo(
tpm2.getUriTemplate().getNormalizedPath());
}
}
}

View File

@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import javax.servlet.http.HttpSessionEvent;
import javax.servlet.http.HttpSessionListener;
public class WsSessionListener implements HttpSessionListener{
private final WsServerContainer wsServerContainer;
public WsSessionListener(WsServerContainer wsServerContainer) {
this.wsServerContainer = wsServerContainer;
}
@Override
public void sessionDestroyed(HttpSessionEvent se) {
wsServerContainer.closeAuthenticatedSession(se.getSession().getId());
}
}

View File

@@ -0,0 +1,128 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket.server;
import java.util.Comparator;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicInteger;
import nginx.unit.websocket.BackgroundProcess;
import nginx.unit.websocket.BackgroundProcessManager;
/**
* Provides timeouts for asynchronous web socket writes. On the server side we
* only have access to {@link javax.servlet.ServletOutputStream} and
* {@link javax.servlet.ServletInputStream} so there is no way to set a timeout
* for writes to the client.
*/
public class WsWriteTimeout implements BackgroundProcess {
private final Set<WsRemoteEndpointImplServer> endpoints =
new ConcurrentSkipListSet<>(new EndpointComparator());
private final AtomicInteger count = new AtomicInteger(0);
private int backgroundProcessCount = 0;
private volatile int processPeriod = 1;
@Override
public void backgroundProcess() {
// This method gets called once a second.
backgroundProcessCount ++;
if (backgroundProcessCount >= processPeriod) {
backgroundProcessCount = 0;
long now = System.currentTimeMillis();
for (WsRemoteEndpointImplServer endpoint : endpoints) {
if (endpoint.getTimeoutExpiry() < now) {
// Background thread, not the thread that triggered the
// write so no need to use a dispatch
endpoint.onTimeout(false);
} else {
// Endpoints are ordered by timeout expiry so if this point
// is reached there is no need to check the remaining
// endpoints
break;
}
}
}
}
@Override
public void setProcessPeriod(int period) {
this.processPeriod = period;
}
/**
* {@inheritDoc}
*
* The default value is 1 which means asynchronous write timeouts are
* processed every 1 second.
*/
@Override
public int getProcessPeriod() {
return processPeriod;
}
public void register(WsRemoteEndpointImplServer endpoint) {
boolean result = endpoints.add(endpoint);
if (result) {
int newCount = count.incrementAndGet();
if (newCount == 1) {
BackgroundProcessManager.getInstance().register(this);
}
}
}
public void unregister(WsRemoteEndpointImplServer endpoint) {
boolean result = endpoints.remove(endpoint);
if (result) {
int newCount = count.decrementAndGet();
if (newCount == 0) {
BackgroundProcessManager.getInstance().unregister(this);
}
}
}
/**
* Note: this comparator imposes orderings that are inconsistent with equals
*/
private static class EndpointComparator implements
Comparator<WsRemoteEndpointImplServer> {
@Override
public int compare(WsRemoteEndpointImplServer o1,
WsRemoteEndpointImplServer o2) {
long t1 = o1.getTimeoutExpiry();
long t2 = o2.getTimeoutExpiry();
if (t1 < t2) {
return -1;
} else if (t1 == t2) {
return 0;
} else {
return 1;
}
}
}
}

View File

@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Server-side specific implementation classes. These are in a separate package
* to make packaging a pure client JAR simpler.
*/
package nginx.unit.websocket.server;