/**
 * Copyright 2009 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS-IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.appjet.ajstdlib;

import scala.collection.mutable.{Queue, HashMap, SynchronizedMap, ArrayBuffer};
import javax.servlet.http.{HttpServletRequest, HttpServletResponse, HttpServlet};
import org.mortbay.jetty.servlet.{ServletHolder, Context};
import org.mortbay.jetty.{HttpConnection, Handler, RetryRequest};
import org.mortbay.jetty.nio.SelectChannelConnector;
import org.mortbay.io.nio.SelectChannelEndPoint;
import org.mortbay.util.ajax.{ContinuationSupport, Continuation};

import java.util.{Timer, TimerTask};
import java.lang.ref.WeakReference;

import org.mozilla.javascript.{Context => JSContext, Scriptable};

import net.appjet.oui._;
import net.appjet.oui.Util.enumerationToRichEnumeration;
import net.appjet.common.util.HttpServletRequestFactory;

// NOTE(review): this file arrived with every newline collapsed into a space,
// which made each `//` comment swallow the rest of the file. Line breaks have
// been restored after each commented-out statement; no code tokens were
// changed. A region near `Comet.frameCode` is additionally corrupted — see the
// NOTE there.

/**
 * Callback interface for comet socket events.
 *
 * Each callback receives the socket involved and the servlet request that
 * triggered the event.
 */
trait SocketConnectionHandler {
  /** Called with a data payload `data` received from `sender`. */
  def message(sender: StreamingSocket, data: String, req: HttpServletRequest);
  /** Called when `socket` connects. */
  def connect(socket: StreamingSocket, req: HttpServletRequest);
  /** Called when `socket` disconnects. */
  def disconnect(socket: StreamingSocket, req: HttpServletRequest);
}

/**
 * Registry of live [[StreamingSocket]]s plus the default event handler that
 * forwards socket events into the `oncomet.js` script.
 */
object SocketManager {
  /** Live sockets keyed by socket id; the SynchronizedMap mixin locks each individual map operation. */
  val sockets = new HashMap[String, StreamingSocket] with SynchronizedMap[String, StreamingSocket];

  /**
   * Default handler shared by every socket created through `apply`.
   *
   * Each event runs `oncomet.js` in a fresh ExecutionContext. The context's
   * attributes are tagged with `cometOperation` ("message" / "connect" /
   * "disconnect"), `cometId` and `cometSocket`, plus `cometData` for messages.
   * If execution reports an error status, the error callback throws a
   * [[HandlerException]] carrying the status code and message. The scope
   * runner is always returned to ScopeReuseManager in the completion callback.
   */
  val handler = new SocketConnectionHandler {
    val cometLib = new FixedDiskLibrary(new SpecialJarOrNotFile(config.ajstdlibHome, "oncomet.js"));
    def cometExecutable = cometLib.executable;

    /** Runs oncomet.js for an incoming message and records its latency in `cometlatencies`. */
    def message(socket: StreamingSocket, data: String, req: HttpServletRequest) {
      val t1 = profiler.time;
      // println("Message from: "+socket.id+": "+data);
      val runner = ScopeReuseManager.getRunner;
      val ec = ExecutionContext(new RequestWrapper(req), new ResponseWrapper(null), runner);
      ec.attributes("cometOperation") = "message";
      ec.attributes("cometId") = socket.id;
      ec.attributes("cometData") = data;
      ec.attributes("cometSocket") = socket;
      net.appjet.oui.execution.execute(
        ec,
        (sc: Int, msg: String) =>
          throw new HandlerException(sc, msg, null),
        () => {},
        () => { ScopeReuseManager.freeRunner(runner); },
        Some(cometExecutable));
      // Elapsed profiler time divided by 1000 — presumably a unit conversion
      // (e.g. ns -> us); TODO confirm against `profiler.time`.
      cometlatencies.register(((profiler.time-t1)/1000).toInt);
    }

    /** Runs oncomet.js synchronously, on the caller's thread, for a new connection. */
    def connect(socket: StreamingSocket, req: HttpServletRequest) {
      // println("Connect on: "+socket);
      val runner = ScopeReuseManager.getRunner;
      val ec = ExecutionContext(new RequestWrapper(req), new ResponseWrapper(null), runner);
      ec.attributes("cometOperation") = "connect";
      ec.attributes("cometId") = socket.id;
      ec.attributes("cometSocket") = socket;
      net.appjet.oui.execution.execute(
        ec,
        (sc: Int, msg: String) =>
          throw new HandlerException(sc, msg, null),
        () => {},
        () => { ScopeReuseManager.freeRunner(runner); },
        Some(cometExecutable));
    }

    /**
     * Runs oncomet.js for a disconnect asynchronously.
     *
     * Unlike `message` and `connect`, the work is handed to the Jetty server
     * thread pool, so this method returns before the script runs.
     */
    def disconnect(socket: StreamingSocket, req: HttpServletRequest) {
      val toRun = new Runnable {
        def run() {
          val runner = ScopeReuseManager.getRunner;
          val ec = ExecutionContext(new RequestWrapper(req), new ResponseWrapper(null), runner);
          ec.attributes("cometOperation") = "disconnect";
          ec.attributes("cometId") = socket.id;
          ec.attributes("cometSocket") = socket;
          net.appjet.oui.execution.execute(
            ec,
            (sc: Int, msg: String) =>
              throw new HandlerException(sc, msg, null),
            () => {},
            () => { ScopeReuseManager.freeRunner(runner); },
            Some(cometExecutable));
        }
      }
      main.server.getThreadPool().dispatch(toRun);
    }
  }

  /**
   * Looks up, or optionally creates, the socket with the given id.
   *
   * @param id     socket id
   * @param create if true, return the existing socket or register a new
   *               [[StreamingSocket]] wired to `handler`
   * @return `Some(socket)` when found or created, `None` when `create` is
   *         false and no socket has this id
   */
  def apply(id: String, create: Boolean) = {
    if (create) {
      // NOTE(review): the null-id guard below only runs when create is false;
      // a null id here goes straight into the map — confirm that is intended.
      Some(sockets.getOrElseUpdate(id, new StreamingSocket(id, handler)));
    } else {
      if (id == null)
        error("bad id: "+id);
      sockets.get(id);
    }
  }

  /**
   * Raised by the comet script error callback.
   *
   * @param sc    status code reported by the script execution
   * @param msg   error message reported by the script execution
   * @param cause underlying exception, if any (always null at the call sites above)
   */
  class HandlerException(val sc: Int, val msg: String, val cause: Exception)
    extends RuntimeException("An error occurred while handling a request: "+sc+" - "+msg, cause);
}

// And this would be the javascript interface. Whee.
/**
 * Entry point that mounts the streaming servlet and serves the client-side
 * assets (`streaming-client.js`, `streaming-iframe.html`).
 */
object Comet extends CometSupport.CometHandler {
  /** Registers this object as CometSupport's handler and starts the embedded Jetty context. */
  def init() {
    CometSupport.cometHandler = this;
    context.start();
  }

  /**
   * JavaScript array literal naming the transports clients may use.
   *
   * The string looks like `['shortpolling', 'longpolling', 'streaming']`:
   * "shortpolling" unless `config.disableShortPolling`, "longpolling" only
   * when wildcard subdomains are enabled, and always "streaming".
   */
  val acceptableTransports = {
    val t = new ArrayBuffer[String];
    if (! config.disableShortPolling) {
      t += "shortpolling";
    }
    if (config.transportUseWildcardSubdomains) {
      t += "longpolling";
    }
    t += "streaming";
    t.mkString("['", "', '", "']");
  }

  // A sessionless, security-less Jetty context rooted at "/" that sends every
  // path to StreamingSocketServlet, with POST forms capped at 1 MiB.
  val servlet = new StreamingSocketServlet();
  val holder = new ServletHolder(servlet);
  val context = new Context(null, "/", Context.NO_SESSIONS | Context.NO_SECURITY);
  context.addServlet(holder, "/*");
  context.setMaxFormContentSize(1024*1024);

  /**
   * Forwards a comet request into `context`, after stripping
   * `config.transportPrefix` from the request URI.
   *
   * Assumes the URI is at least as long as the prefix; `substring` throws otherwise.
   */
  def handleCometRequest(req: HttpServletRequest, res: HttpServletResponse) {
    context.handle(req.getRequestURI().substring(config.transportPrefix.length), req, res, Handler.FORWARD);
  }

  /** The client library template, loaded from disk or the jar. */
  lazy val ccLib = new FixedDiskResource(new JarOrNotFile(config.ajstdlibHome, "streaming-client.js") {
    override val classBase = "/net/appjet/ajstdlib/";
    override val fileSep = "/../../net.appjet.ajstdlib/";
  });

  /**
   * Renders the client library by filling in its template placeholders.
   *
   * The placeholders are `%contextPath%`, `"%acceptableChannelTypes%"` and
   * `"%canUseSubdomains%"`.
   *
   * NOTE(review): `replaceAll` treats the replacement as a regex replacement
   * string, so a `$` or `\` in `contextPath` or `acceptableChannelTypes`
   * would be interpreted rather than copied. Consider wrapping them in
   * `java.util.regex.Matcher.quoteReplacement`.
   */
  def clientCode(contextPath: String, acceptableChannelTypes: String) = {
    ccLib.contents.replaceAll("%contextPath%", contextPath).replaceAll("\"%acceptableChannelTypes%\"", acceptableChannelTypes).replaceAll("\"%canUseSubdomains%\"", if (config.transportUseWildcardSubdomains) "true" else "false");
  }

  /** Last-modified time of the client library file. */
  def clientMTime = ccLib.fileLastModified;

  /** The cross-domain XHR iframe page template. */
  lazy val ccFrame = new FixedDiskResource(new JarOrNotFile(config.ajstdlibHome, "streaming-iframe.html") {
    override val classBase = "/net/appjet/ajstdlib/";
    override val fileSep = "/../../net.appjet.ajstdlib/";
  });

  // NOTE(review): SOURCE IS CORRUPTED FROM HERE UNTIL `trait OperaChannel`.
  // The text between the `replace("\n` literal below and `; if (txt.length < 256)`
  // has been lost; presumably markup-like spans were stripped when this file
  // was scraped. The loss takes with it:
  //   - the rest of `frameCode`;
  //   - the closing brace of `object Comet`;
  //   - the opening of the channel trait that owns the padding code, `header`,
  //     `sendRestartFailure` and `handleNewConnection` below;
  //   - possibly other definitions. StreamingSocket, Channel, ChannelType,
  //     Channels, XhrChannel and KnowsAboutDispatch are referenced in this
  //     file but not defined anywhere visible.
  // The surviving text is kept verbatim. Restore this region from version
  // control; it cannot be reconstructed from this view.
  def frameCode = {
    if (! config.devMode)
      ccFrame.contents.replace("\n "; if (txt.length < 256) String.format("%256s", txt); else txt; }
  // Orphaned members of the channel trait whose header was lost (see NOTE above).
  // The code above left-pads `txt` with spaces to at least 256 characters.
  // `header` derives a quoted document.domain value from the Host header: it
  // drops the first two dot-separated labels and any ":port" suffix.
  def header(req: HttpServletRequest) = {
    val document_domain = "\""+req.getHeader("Host").split("\\.").slice(2).mkString(".").split(":")(0)+"\"";
    """ f"; // " - damn textmate mode!
  }
  override def sendRestartFailure(ec: ExecutionContext) {
    ec.response.write(header(ec.request.req));
    ec.response.write(controlMessage("restart-fail"));
  }

  override def handleNewConnection(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    super.handleNewConnection(req, res, out);
    res.setContentType("text/html");
    out.append(header(req));
  }
}

/**
 * Streaming channel variant that frames each message as a DOM event-stream
 * event (`Event: message` / `data: ...`, terminated by a blank line) and serves
 * the response as `application/x-dom-event-stream`.
 */
trait OperaChannel extends StreamingChannel {
  override def wrapBody(msgBody: String) = {
    "Event: message\ndata: "+msgBody+"\n\n";
  }
  override def handleNewConnection(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    super.handleNewConnection(req, res, out);
    res.setContentType("application/x-dom-event-stream");
  }
}

/**
 * A long-lived streaming response bound to one [[StreamingSocket]].
 *
 * Each call to `handle` writes whatever is ready and then parks the request
 * on a Jetty RetryContinuation. `messageWaiting` resumes the continuation
 * when more data is available.
 */
class StreamingChannel(val socket: StreamingSocket) extends Channel with XhrChannel {
  def kind = ChannelType.Streaming;

  /** The continuation of the request currently serving this channel, if any. */
  var c: Option[SelectChannelConnector.RetryContinuation] = None;
  /** Set by `close()`; the next `handle` pass closes the endpoint instead of writing. */
  var doClose = false;

  /**
   * Signals that data is ready.
   *
   * Asynchronously, on the server thread pool and under the socket's lock,
   * resumes the stored continuation if it is currently pending.
   */
  def messageWaiting() {
    main.server.getThreadPool().dispatch(new Runnable() {
      def run() {
        socket.synchronized {
          c.filter(_.isPending()).foreach(_.resume());
        }
      }
    });
  }

  /**
   * On a fresh continuation, reports the client's last received sequence
   * number (request parameter `seq`) to the socket.
   *
   * `parseInt` throws NumberFormatException if `seq` is missing or non-numeric.
   */
  def setSequenceNumberIfAppropriate(req: HttpServletRequest) {
    if (c.get.isNew) {
      val lastReceivedSeq = java.lang.Integer.parseInt(req.getParameter("seq"));
      socket.updateConfirmedSeqNumber(this, lastReceivedSeq);
    }
  }

  /** Appends the "ok" control message that opens a stream. */
  def sendHandshake(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    out.append(controlMessage("ok"));
  }

  /** Re-sends every message the socket still considers unconfirmed for this channel. */
  def sendUnconfirmedMessages(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    for (msg <- socket.getUnconfirmedMessages(this)) {
      out.append(wireFormat(msg));
    }
  }

  /** Drains the socket's waiting-message queue for this channel into `out`. */
  def sendWaitingMessages(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    var msg = socket.getWaitingMessage(this);
    while (msg.isDefined) {
      out.append(wireFormat(msg.get));
      msg = socket.getWaitingMessage(this);
    }
  }

  /** Marks the socket as having hiccuped on this channel and closes the endpoint. */
  def handleUnexpectedDisconnect(req: HttpServletRequest, res: HttpServletResponse, ep: KnowsAboutDispatch) {
    socket.synchronized {
      socket.hiccup(this);
    }
    ep.close();
  }

  /** Writes `out` to the raw servlet writer and flushes it; runs even when `out` is empty. */
  def writeAndFlush(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder, ep: KnowsAboutDispatch) {
    // println("Writing to "+socket.id+": "+out.toString);
    res.getWriter.print(out.toString);
    res.getWriter.flush();
  }

  /**
   * Always parks the request.
   *
   * Arms a 50-second keepalive timer, undispatches the endpoint, and suspends
   * the continuation with timeout 0.
   */
  def suspendIfNecessary(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder, ep: KnowsAboutDispatch) {
    scheduleKeepalive(50*1000);
    ep.undispatch();
    c.get.suspend(0);
  }

  /** Appends a "keepalive" control message only if nothing else is being sent this pass. */
  def sendKeepaliveIfNecessary(out: StringBuilder, sendKeepalive: Boolean) {
    if (out.length == 0 && sendKeepalive) {
      out.append(controlMessage("keepalive"));
    }
  }

  /** Streaming channels handshake exactly once, on the first (new) continuation. */
  def shouldHandshake(req: HttpServletRequest, res: HttpServletResponse) = c.get.isNew;

  /** Set by the keepalive timer; consumed and cleared at the start of `handle`. */
  var sendKeepalive = false;
  var keepaliveTask: TimerTask = null;

  /**
   * (Re)arms the keepalive timer.
   *
   * After `timeout` milliseconds (java.util.Timer delay), the task sets
   * `sendKeepalive` and wakes the channel via `messageWaiting`. Any previously
   * scheduled task is cancelled first. The task reaches the channel through
   * a WeakReference, apparently so that a pending timer does not keep the
   * channel alive (confirm the anonymous TimerTask does not also capture an
   * outer reference).
   *
   * NOTE(review): the task writes `sendKeepalive` under `channel.synchronized`,
   * but `handle` reads and clears it under `socket.synchronized`. These are
   * different monitors, so the flag is not consistently guarded.
   */
  def scheduleKeepalive(timeout: Int) {
    if (keepaliveTask != null) {
      keepaliveTask.cancel();
    }
    val p = new WeakReference(this);
    keepaliveTask = new TimerTask {
      def run() {
        val channel = p.get();
        if (channel != null) {
          channel.synchronized {
            channel.sendKeepalive = true;
            channel.messageWaiting();
          }
        }
      }
    }
    Channels.timer.schedule(keepaliveTask, timeout);
  }

  /**
   * First-request setup.
   *
   * Stores this channel on the request (so a re-dispatched request is routed
   * straight back here by the servlet), disables connection reuse, sets the
   * no-cache headers, and sets the streaming content type.
   */
  def handleNewConnection(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    req.setAttribute("StreamingSocketServlet_channel", this);
    res.setHeader("Connection", "close");
    for ((k, v) <- Util.noCacheHeaders) { res.setHeader(k, v); } // maybe this will help with proxies?
    res.setContentType("text/messages; charset=utf-8");
  }

  /**
   * Serves one dispatch of this channel's request.
   *
   * Everything between the `try` and `finally` runs under the socket's lock:
   *  1. Snapshot and clear the pending-keepalive flag; cancel the keepalive timer.
   *  2. Fetch the request's continuation (using the socket as its mutex) and,
   *     if it is new, record the client's confirmed sequence number.
   *  3. If `close()` was requested, close the endpoint and return.
   *  4. On a new continuation, run first-request setup. Otherwise call
   *     `suspend(-1)`; if the endpoint is still dispatched, treat that as an
   *     unexpected disconnect and return. (The meaning of -1 comes from
   *     Jetty's RetryContinuation; confirm there.)
   *  5. Handshake and replay unconfirmed messages when `shouldHandshake`.
   *  6. Drain waiting messages, add a keepalive if one is due and nothing
   *     else was written, then park the request via `suspendIfNecessary`.
   *
   * The `finally` calls `writeAndFlush` on every exit path, including both
   * early returns and any exception thrown while suspending. The servlet's
   * doGet explicitly rethrows RetryRequest, which suggests suspension may
   * escape this method by throwing.
   */
  def handle(req: HttpServletRequest, res: HttpServletResponse) {
    val ec = req.getAttribute("executionContext").asInstanceOf[ExecutionContext];
    val ep = HttpConnection.getCurrentConnection.getEndPoint.asInstanceOf[KnowsAboutDispatch];
    val out = new StringBuilder;
    try {
      socket.synchronized {
        val sendKeepaliveNow = sendKeepalive;
        sendKeepalive = false;
        if (keepaliveTask != null) {
          keepaliveTask.cancel();
          keepaliveTask = null;
        }
        c = Some(ContinuationSupport.getContinuation(req, socket).asInstanceOf[SelectChannelConnector.RetryContinuation]);
        setSequenceNumberIfAppropriate(req);
        if (doClose) {
          ep.close();
          return;
        }
        if (c.get.isNew) {
          handleNewConnection(req, res, out);
        } else {
          c.get.suspend(-1);
          if (ep.isDispatched) {
            handleUnexpectedDisconnect(req, res, ep);
            return;
          }
        }
        if (shouldHandshake(req, res)) {
          // println("new stream request: "+socket.id);
          sendHandshake(req, res, out);
          sendUnconfirmedMessages(req, res, out);
        }
        sendWaitingMessages(req, res, out);
        sendKeepaliveIfNecessary(out, sendKeepaliveNow);
        suspendIfNecessary(req, res, out, ep);
      }
    } finally {
      writeAndFlush(req, res, out, ep);
    }
  }

  /** Requests shutdown: flags the channel and wakes the parked request so `handle` closes the endpoint. */
  def close() {
    doClose = true;
    messageWaiting();
  }

  def isConnected = ! doClose;
}

/**
 * Long-polling variant of [[StreamingChannel]].
 *
 * Each request returns as soon as there is anything to send. After writing,
 * the channel detaches from the request (`c = None`) and reports a hiccup to
 * the socket.
 */
class LongPollingChannel(socket: StreamingSocket) extends StreamingChannel(socket) {
  // println("creating longpoll!");
  override def kind = ChannelType.LongPolling;

  /** Handshakes only when the client explicitly sends `new=yes`. */
  override def shouldHandshake(req: HttpServletRequest, res: HttpServletResponse) =
    req.getParameter("new") == "yes";

  override def sendHandshake(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    // println("sending handshake");
    out.append(controlMessage("ok"));
  }

  /**
   * Parks the request only when there is nothing to send.
   *
   * The keepalive delay comes from the `timeout` request parameter; `parseInt`
   * throws NumberFormatException if it is missing or non-numeric.
   */
  override def suspendIfNecessary(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder, ep: KnowsAboutDispatch) {
    if (out.length == 0) {
      // println("suspending longpoll: "+socket.id);
      val to = java.lang.Integer.parseInt(req.getParameter("timeout"));
      // println("LongPoll scheduling keepalive for: "+to);
      scheduleKeepalive(to);
      ep.undispatch();
      c.get.suspend(0);
    }
  }

  /**
   * Writes only non-empty output.
   *
   * Output goes through the ExecutionContext's response wrapper rather than
   * the raw servlet writer, with no-cache headers set. Afterwards the socket
   * is marked as hiccuped and the continuation is cleared, so `isConnected`
   * becomes false until the next poll.
   */
  override def writeAndFlush(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder, ep: KnowsAboutDispatch) {
    if (out.length > 0) {
      // println("Writing to "+socket.id+": "+out.toString);
      // println("writing and flushing longpoll")
      val ec = req.getAttribute("executionContext").asInstanceOf[ExecutionContext];
      for ((k, v) <- Util.noCacheHeaders) { ec.response.setHeader(k, v); } // maybe this will help with proxies?
      // println("writing: "+out);
      ec.response.write(out.toString);
      socket.synchronized {
        socket.hiccup(this);
        c = None;
      }
    }
  }

  /**
   * Revives this channel on the socket and tags the request with it.
   *
   * Deliberately does not call super, so none of StreamingChannel's
   * Connection/no-cache/content-type headers are set here.
   */
  override def handleNewConnection(req: HttpServletRequest, res: HttpServletResponse, out: StringBuilder) {
    socket.revive(this);
    req.setAttribute("StreamingSocketServlet_channel", this);
  }

  /** Connected while a request (continuation) is attached, checked under the socket's lock. */
  override def isConnected = socket.synchronized {
    c.isDefined;
  }
}

/**
 * Servlet behind `Comet.context`.
 *
 * GET serves the client assets and the channel (downstream) requests. POST
 * carries client-to-server messages and control ("oob") messages.
 */
class StreamingSocketServlet extends HttpServlet {
  /** Protocol version the client must send as the `v` parameter. */
  val version = 2;

  /**
   * Handles GET requests.
   *
   * Routing:
   *  - `/js/client.js`: the rendered client library.
   *  - `/xhrXdFrame`: the cross-domain iframe page.
   *  - anything else: a channel request. It requires `v == version`, then
   *    either resumes the channel stored on the request or resolves one from
   *    the `id`, `channel`, `new`, `create` and `type` parameters. If no
   *    channel is found, it logs a restart-failure and sends a channel-specific
   *    restart failure, or a 404 for an unknown channel type.
   *
   * RetryRequest is rethrown untouched. Any other Throwable is logged and
   * turned into a 500 whose body is the exception message.
   *
   * NOTE(review): "So such socket" below is presumably a typo for "No such
   * socket". It is a runtime string and was left unchanged here.
   */
  override def doGet(req: HttpServletRequest, res: HttpServletResponse) {
    // describeRequest(req);
    val ec = req.getAttribute("executionContext").asInstanceOf[ExecutionContext];
    try {
      if (req.getPathInfo() == "/js/client.js") {
        val contextPath = config.transportPrefix;
        val acceptableTransports = Comet.acceptableTransports;
        ec.response.setContentType("application/x-javascript");
        ec.response.write(Comet.clientCode(contextPath, acceptableTransports));
      } else if (req.getPathInfo() == "/xhrXdFrame") {
        ec.response.setContentType("text/html; charset=utf-8");
        ec.response.write(Comet.frameCode);
      } else {
        val v = req.getParameter("v");
        if (v == null || java.lang.Integer.parseInt(v) != version) {
          res.sendError(HttpServletResponse.SC_BAD_REQUEST, "bad version number!");
          return;
        }
        // Set by a channel's handleNewConnection on an earlier dispatch of this same request.
        val existingChannel = req.getAttribute("StreamingSocketServlet_channel");
        if (existingChannel != null) {
          existingChannel.asInstanceOf[Channel].handle(req, res);
        } else {
          val socketId = req.getParameter("id");
          val channelType = req.getParameter("channel");
          val isNew = req.getParameter("new") == "yes";
          val shouldCreateSocket = req.getParameter("create") == "yes";
          val subType = req.getParameter("type");
          val channel = SocketManager(socketId, shouldCreateSocket).map(_.channel(channelType, isNew, subType)).getOrElse(None);
          if (channel.isDefined) {
            channel.get.handle(req, res);
          } else {
            streaminglog(Map(
              "type" -> "event",
              "event" -> "restart-failure",
              "connection" -> socketId));
            // A throwaway channel (null socket) used only to emit a restart
            // failure in the client's wire format.
            val failureChannel = ChannelType.valueOf(channelType).map(Channels.createNew(_, null, subType));
            if (failureChannel.isDefined) {
              failureChannel.get.sendRestartFailure(ec);
            } else {
              ec.response.setStatusCode(HttpServletResponse.SC_NOT_FOUND);
              ec.response.write("So such socket, and/or unknown channel type: "+channelType);
            }
          }
        }
      }
    } catch {
      case e: RetryRequest => throw e;
      case t: Throwable => {
        exceptionlog("A comet error occurred: ");
        exceptionlog(t);
        ec.response.setStatusCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        ec.response.write(t.getMessage());
      }
    }
  }

  /** Debug helper: prints the method, URI, query string and every parameter to stdout. */
  def describeRequest(req: HttpServletRequest) {
    println(req.getMethod+" on "+req.getRequestURI()+"?"+req.getQueryString());
    for (pname <- req.getParameterNames.asInstanceOf[java.util.Enumeration[String]]) {
      println(" "+pname+" -> "+req.getParameterValues(pname).mkString("[", ",", "]"));
    }
  }

  /**
   * Handles POSTs carrying client-to-server traffic for an existing socket.
   *
   * Parameters:
   *  - `v`: protocol version; must equal `version`.
   *  - `id`: socket id. An unknown id yields the body "restart-fail".
   *  - `seq`: the client's last confirmed sequence number.
   *  - `m`: zero or more data messages, delivered in order.
   *  - `oob`: zero or more control messages. These are "hiccup",
   *    "useChannel:<int>:<string>", "kill..." (the text after the first five
   *    characters, i.e. after a "kill:" prefix, becomes the reason); anything
   *    else is logged as unknown.
   * Responds "ok" on success. A HandlerException or any other Throwable
   * becomes a 500 whose body is the exception message.
   *
   * NOTE(review): `parseInt` on `v` and `seq` runs outside the `try`, so a
   * non-numeric or missing value throws NumberFormatException out of doPost
   * instead of producing a 400/500 response.
   */
  override def doPost(req: HttpServletRequest, res: HttpServletResponse) {
    val v = req.getParameter("v");
    if (v == null || java.lang.Integer.parseInt(v) != version) {
      res.sendError(HttpServletResponse.SC_BAD_REQUEST, "bad version number!");
      return;
    }
    val ec = req.getAttribute("executionContext").asInstanceOf[ExecutionContext];
    val socketId = req.getParameter("id");
    val socket = SocketManager(socketId, false);

    // describeRequest(req);

    if (socket.isEmpty) {
      ec.response.write("restart-fail");
      streaminglog(Map(
        "type" -> "event",
        "event" -> "restart-failure",
        "connection" -> socketId));
      // println("socket restart-fail: "+socketId);
    } else {
      val seq = java.lang.Integer.parseInt(req.getParameter("seq"));
      // null channel: the confirmation comes from a POST, not from a specific channel.
      socket.get.updateConfirmedSeqNumber(null, seq);
      val messages = req.getParameterValues("m");
      val controlMessages = req.getParameterValues("oob");
      try {
        if (messages != null)
          for (msg <- messages) socket.get.receiveMessage(msg, req);
        if (controlMessages != null)
          for (msg <- controlMessages) {
            // println("Control message from "+socket.get.id+": "+msg);
            msg match {
              case "hiccup" => {
                streaminglog(Map(
                  "type" -> "event",
                  "event" -> "hiccup",
                  "connection" -> socketId));
                socket.get.prepareForReconnect();
              }
              case _ => {
                if (msg.startsWith("useChannel")) {
                  // Expects "useChannel:<int>:<string>". Fewer than three parts
                  // throws ArrayIndexOutOfBoundsException, which the catch below
                  // turns into a 500.
                  val msgParts = msg.split(":");
                  socket.get.useChannel(java.lang.Integer.parseInt(msgParts(1)), msgParts(2), req);
                } else if (msg.startsWith("kill")) {
                  // Math.min guards a bare "kill" that is shorter than "kill:".
                  socket.get.kill("client request: "+msg.substring(Math.min(msg.length, "kill:".length)));
                } else {
                  streaminglog(Map(
                    "type" -> "error",
                    "error" -> "unknown control message",
                    "connection" -> socketId,
                    "message" -> msg));
                }
              }
            }
          }
        ec.response.write("ok");
      } catch {
        case e: SocketManager.HandlerException => {
          exceptionlog(e);
          ec.response.setStatusCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
          ec.response.write(e.getMessage());
          // log these?
        }
        case t: Throwable => {
          // shouldn't happen...
          exceptionlog(t);
          ec.response.setStatusCode(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
          ec.response.write(t.getMessage());
        }
      }
    }
  }
}