001package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket; 002 003/* 004 * #%L 005 * HAPI FHIR Subscription Server 006 * %% 007 * Copyright (C) 2014 - 2020 University Health Network 008 * %% 009 * Licensed under the Apache License, Version 2.0 (the "License"); 010 * you may not use this file except in compliance with the License. 011 * You may obtain a copy of the License at 012 * 013 * http://www.apache.org/licenses/LICENSE-2.0 014 * 015 * Unless required by applicable law or agreed to in writing, software 016 * distributed under the License is distributed on an "AS IS" BASIS, 017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 018 * See the License for the specific language governing permissions and 019 * limitations under the License. 020 * #L% 021 */ 022 023import ca.uhn.fhir.context.FhirContext; 024import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription; 025import ca.uhn.fhir.jpa.subscription.module.channel.SubscriptionChannelRegistry; 026import ca.uhn.fhir.jpa.subscription.module.channel.SubscriptionChannelWithHandlers; 027import ca.uhn.fhir.jpa.subscription.module.subscriber.ResourceDeliveryMessage; 028import org.hl7.fhir.instance.model.api.IIdType; 029import org.hl7.fhir.r4.model.IdType; 030import org.slf4j.Logger; 031import org.slf4j.LoggerFactory; 032import org.springframework.beans.factory.annotation.Autowired; 033import org.springframework.messaging.Message; 034import org.springframework.messaging.MessageHandler; 035import org.springframework.messaging.MessagingException; 036import org.springframework.web.socket.CloseStatus; 037import org.springframework.web.socket.TextMessage; 038import org.springframework.web.socket.WebSocketHandler; 039import org.springframework.web.socket.WebSocketSession; 040import org.springframework.web.socket.handler.TextWebSocketHandler; 041 042import javax.annotation.PostConstruct; 043import javax.annotation.PreDestroy; 044import java.io.IOException; 045 046public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler { 047 private static Logger ourLog = LoggerFactory.getLogger(SubscriptionWebsocketHandler.class); 048 @Autowired 049 protected WebsocketConnectionValidator myWebsocketConnectionValidator; 050 @Autowired 051 SubscriptionChannelRegistry mySubscriptionChannelRegistry; 052 053 @Autowired 054 private FhirContext myCtx; 055 056 private IState myState = new InitialState(); 057 058 @Override 059 public void afterConnectionClosed(WebSocketSession theSession, CloseStatus theStatus) throws Exception { 060 super.afterConnectionClosed(theSession, theStatus); 061 ourLog.info("Closing WebSocket connection from {}", theSession.getRemoteAddress()); 062 } 063 064 @Override 065 public void afterConnectionEstablished(WebSocketSession theSession) throws Exception { 066 super.afterConnectionEstablished(theSession); 067 ourLog.info("Incoming WebSocket connection from {}", theSession.getRemoteAddress()); 068 } 069 070 protected void handleFailure(Exception theE) { 071 ourLog.error("Failure during communication", theE); 072 } 073 074 @Override 075 protected void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) throws Exception { 076 ourLog.info("Textmessage: " + theMessage.getPayload()); 077 myState.handleTextMessage(theSession, theMessage); 078 } 079 080 @Override 081 public void handleTransportError(WebSocketSession theSession, Throwable theException) throws Exception { 082 super.handleTransportError(theSession, theException); 083 ourLog.error("Transport error", theException); 084 } 085 086 @PostConstruct 087 public synchronized void postConstruct() { 088 ourLog.info("Websocket connection has been created"); 089 } 090 091 @PreDestroy 092 public synchronized void preDescroy() { 093 ourLog.info("Websocket connection is closing"); 094 IState state = myState; 095 if (state != null) { 096 state.closing(); 097 } 098 } 099 100 101 private interface IState { 102 103 void closing(); 104 105 void handleTextMessage(WebSocketSession theSession, TextMessage theMessage); 106 107 } 108 109 private class BoundStaticSubscriptionState implements IState, MessageHandler { 110 111 private final WebSocketSession mySession; 112 private final ActiveSubscription myActiveSubscription; 113 114 public BoundStaticSubscriptionState(WebSocketSession theSession, ActiveSubscription theActiveSubscription) { 115 mySession = theSession; 116 myActiveSubscription = theActiveSubscription; 117 118 SubscriptionChannelWithHandlers subscriptionChannelWithHandlers = mySubscriptionChannelRegistry.get(theActiveSubscription.getChannelName()); 119 subscriptionChannelWithHandlers.addHandler(this); 120 } 121 122 @Override 123 public void closing() { 124 SubscriptionChannelWithHandlers subscriptionChannelWithHandlers = mySubscriptionChannelRegistry.get(myActiveSubscription.getChannelName()); 125 subscriptionChannelWithHandlers.removeHandler(this); 126 } 127 128 private void deliver() { 129 try { 130 String payload = "ping " + myActiveSubscription.getId(); 131 ourLog.info("Sending WebSocket message: {}", payload); 132 mySession.sendMessage(new TextMessage(payload)); 133 } catch (IOException e) { 134 handleFailure(e); 135 } 136 } 137 138 @Override 139 public void handleMessage(Message<?> theMessage) { 140 if (!(theMessage.getPayload() instanceof ResourceDeliveryMessage)) { 141 return; 142 } 143 try { 144 ResourceDeliveryMessage msg = (ResourceDeliveryMessage) theMessage.getPayload(); 145 if (myActiveSubscription.getSubscription().equals(msg.getSubscription())) { 146 deliver(); 147 } 148 } catch (Exception e) { 149 ourLog.error("Failure handling subscription payload", e); 150 throw new MessagingException(theMessage, "Failure handling subscription payload", e); 151 } 152 } 153 154 @Override 155 public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) { 156 try { 157 theSession.sendMessage(new TextMessage("Unexpected client message: " + theMessage.getPayload())); 158 } catch (IOException e) { 159 handleFailure(e); 160 } 161 } 162 } 163 164 private class InitialState implements IState { 165 166 private IIdType bindSimple(WebSocketSession theSession, String theBindString) { 167 IdType id = new IdType(theBindString); 168 169 WebsocketValidationResponse response = myWebsocketConnectionValidator.validate(id); 170 if (!response.isValid()) { 171 try { 172 ourLog.warn(response.getMessage()); 173 theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), response.getMessage())); 174 } catch (IOException e) { 175 handleFailure(e); 176 } 177 return null; 178 } 179 180 myState = new BoundStaticSubscriptionState(theSession, response.getActiveSubscription()); 181 182 return id; 183 } 184 185 @Override 186 public void closing() { 187 // nothing 188 } 189 190 @Override 191 public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) { 192 String message = theMessage.getPayload(); 193 if (message.startsWith("bind ")) { 194 String remaining = message.substring("bind ".length()); 195 196 IIdType subscriptionId; 197 subscriptionId = bindSimple(theSession, remaining); 198 if (subscriptionId == null) { 199 return; 200 } 201 202 try { 203 theSession.sendMessage(new TextMessage("bound " + subscriptionId.getIdPart())); 204 } catch (IOException e) { 205 handleFailure(e); 206 } 207 208 } 209 } 210 211 } 212 213}