001package ca.uhn.fhir.jpa.subscription.module.subscriber.websocket;
002
003/*
004 * #%L
005 * HAPI FHIR Subscription Server
006 * %%
007 * Copyright (C) 2014 - 2020 University Health Network
008 * %%
009 * Licensed under the Apache License, Version 2.0 (the "License");
010 * you may not use this file except in compliance with the License.
011 * You may obtain a copy of the License at
012 *
013 *      http://www.apache.org/licenses/LICENSE-2.0
014 *
015 * Unless required by applicable law or agreed to in writing, software
016 * distributed under the License is distributed on an "AS IS" BASIS,
017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
018 * See the License for the specific language governing permissions and
019 * limitations under the License.
020 * #L%
021 */
022
023import ca.uhn.fhir.context.FhirContext;
024import ca.uhn.fhir.jpa.subscription.module.cache.ActiveSubscription;
025import ca.uhn.fhir.jpa.subscription.module.channel.SubscriptionChannelRegistry;
026import ca.uhn.fhir.jpa.subscription.module.channel.SubscriptionChannelWithHandlers;
027import ca.uhn.fhir.jpa.subscription.module.subscriber.ResourceDeliveryMessage;
028import org.hl7.fhir.instance.model.api.IIdType;
029import org.hl7.fhir.r4.model.IdType;
030import org.slf4j.Logger;
031import org.slf4j.LoggerFactory;
032import org.springframework.beans.factory.annotation.Autowired;
033import org.springframework.messaging.Message;
034import org.springframework.messaging.MessageHandler;
035import org.springframework.messaging.MessagingException;
036import org.springframework.web.socket.CloseStatus;
037import org.springframework.web.socket.TextMessage;
038import org.springframework.web.socket.WebSocketHandler;
039import org.springframework.web.socket.WebSocketSession;
040import org.springframework.web.socket.handler.TextWebSocketHandler;
041
042import javax.annotation.PostConstruct;
043import javax.annotation.PreDestroy;
044import java.io.IOException;
045
046public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler {
047        private static Logger ourLog = LoggerFactory.getLogger(SubscriptionWebsocketHandler.class);
048        @Autowired
049        protected WebsocketConnectionValidator myWebsocketConnectionValidator;
050        @Autowired
051        SubscriptionChannelRegistry mySubscriptionChannelRegistry;
052
053        @Autowired
054        private FhirContext myCtx;
055
056        private IState myState = new InitialState();
057
058        @Override
059        public void afterConnectionClosed(WebSocketSession theSession, CloseStatus theStatus) throws Exception {
060                super.afterConnectionClosed(theSession, theStatus);
061                ourLog.info("Closing WebSocket connection from {}", theSession.getRemoteAddress());
062        }
063
064        @Override
065        public void afterConnectionEstablished(WebSocketSession theSession) throws Exception {
066                super.afterConnectionEstablished(theSession);
067                ourLog.info("Incoming WebSocket connection from {}", theSession.getRemoteAddress());
068        }
069
070        protected void handleFailure(Exception theE) {
071                ourLog.error("Failure during communication", theE);
072        }
073
074        @Override
075        protected void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) throws Exception {
076                ourLog.info("Textmessage: " + theMessage.getPayload());
077                myState.handleTextMessage(theSession, theMessage);
078        }
079
080        @Override
081        public void handleTransportError(WebSocketSession theSession, Throwable theException) throws Exception {
082                super.handleTransportError(theSession, theException);
083                ourLog.error("Transport error", theException);
084        }
085
086        @PostConstruct
087        public synchronized void postConstruct() {
088                ourLog.info("Websocket connection has been created");
089        }
090
091        @PreDestroy
092        public synchronized void preDescroy() {
093                ourLog.info("Websocket connection is closing");
094                IState state = myState;
095                if (state != null) {
096                        state.closing();
097                }
098        }
099
100
101        private interface IState {
102
103                void closing();
104
105                void handleTextMessage(WebSocketSession theSession, TextMessage theMessage);
106
107        }
108
109        private class BoundStaticSubscriptionState implements IState, MessageHandler {
110
111                private final WebSocketSession mySession;
112                private final ActiveSubscription myActiveSubscription;
113
114                public BoundStaticSubscriptionState(WebSocketSession theSession, ActiveSubscription theActiveSubscription) {
115                        mySession = theSession;
116                        myActiveSubscription = theActiveSubscription;
117
118                        SubscriptionChannelWithHandlers subscriptionChannelWithHandlers = mySubscriptionChannelRegistry.get(theActiveSubscription.getChannelName());
119                        subscriptionChannelWithHandlers.addHandler(this);
120                }
121
122                @Override
123                public void closing() {
124                        SubscriptionChannelWithHandlers subscriptionChannelWithHandlers = mySubscriptionChannelRegistry.get(myActiveSubscription.getChannelName());
125                        subscriptionChannelWithHandlers.removeHandler(this);
126                }
127
128                private void deliver() {
129                        try {
130                                String payload = "ping " + myActiveSubscription.getId();
131                                ourLog.info("Sending WebSocket message: {}", payload);
132                                mySession.sendMessage(new TextMessage(payload));
133                        } catch (IOException e) {
134                                handleFailure(e);
135                        }
136                }
137
138                @Override
139                public void handleMessage(Message<?> theMessage) {
140                        if (!(theMessage.getPayload() instanceof ResourceDeliveryMessage)) {
141                                return;
142                        }
143                        try {
144                                ResourceDeliveryMessage msg = (ResourceDeliveryMessage) theMessage.getPayload();
145                                if (myActiveSubscription.getSubscription().equals(msg.getSubscription())) {
146                                        deliver();
147                                }
148                        } catch (Exception e) {
149                                ourLog.error("Failure handling subscription payload", e);
150                                throw new MessagingException(theMessage, "Failure handling subscription payload", e);
151                        }
152                }
153
154                @Override
155                public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) {
156                        try {
157                                theSession.sendMessage(new TextMessage("Unexpected client message: " + theMessage.getPayload()));
158                        } catch (IOException e) {
159                                handleFailure(e);
160                        }
161                }
162        }
163
164        private class InitialState implements IState {
165
166                private IIdType bindSimple(WebSocketSession theSession, String theBindString) {
167                        IdType id = new IdType(theBindString);
168
169                        WebsocketValidationResponse response = myWebsocketConnectionValidator.validate(id);
170                        if (!response.isValid()) {
171                                try {
172                                        ourLog.warn(response.getMessage());
173                                        theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), response.getMessage()));
174                                } catch (IOException e) {
175                                        handleFailure(e);
176                                }
177                                return null;
178                        }
179
180                        myState = new BoundStaticSubscriptionState(theSession, response.getActiveSubscription());
181
182                        return id;
183                }
184
185                @Override
186                public void closing() {
187                        // nothing
188                }
189
190                @Override
191                public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) {
192                        String message = theMessage.getPayload();
193                        if (message.startsWith("bind ")) {
194                                String remaining = message.substring("bind ".length());
195
196                                IIdType subscriptionId;
197                                subscriptionId = bindSimple(theSession, remaining);
198                                if (subscriptionId == null) {
199                                        return;
200                                }
201
202                                try {
203                                        theSession.sendMessage(new TextMessage("bound " + subscriptionId.getIdPart()));
204                                } catch (IOException e) {
205                                        handleFailure(e);
206                                }
207
208                        }
209                }
210
211        }
212
213}