spring-websocket实现聊天室功能

最近看到有些人的博客中有聊天室的功能所以我也在我博客中写了一个,不过他们用的是java原生的,这里我使用了spring封装的spring-websocket

Spring-WebSocket配置

我们第一步要先配置一下websocket 的基本信息

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
* @Author: ZVerify
* @Description: TODO WebSocket相关配置
* @DateTime: 2022/9/6 14:21
**/
@Configuration
@EnableWebSocket
public class ZVerifyWebSocketConfig implements WebSocketConfigurer {

// 注册 WebSocket 处理器
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
webSocketHandlerRegistry
// WebSocket 连接处理器
.addHandler(new ZVerifyWebSocketHandler(), "/ws-connect")
// WebSocket 拦截器
.addInterceptors(new ZVerifyWebSocketInterceptor())
// 允许跨域
.setAllowedOrigins("*");
}

}

其中连接处理器和拦截器是我们自己定义的

"/ws-connect"就是我们的路径

因为想要建立连接首先要通过我们的拦截器所以按照逻辑来写拦截器

前置拦截器

这个前置拦截器一般我们会做安全的校验和一系列处理,这里我就简单了写了一下,这里要做安全校验是因为我们定义的websocket并没有托管给我所使用的安全框架去验证用户,所以在这里要简单校验一下,

前置处理器的创建要去实现HandshakeInterceptor接口然后重写beforeHandshake,afterHandshake,两个方法,beforeHandshake是用做握手前置校验的,afterHandshake是做握手后置校验的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/**
* @Author: ZVerify
* @Description: TODO WebSocket 前置拦截器
* @DateTime: 2022/9/6 14:19
**/
@Configuration
public class ZVerifyWebSocketInterceptor implements HandshakeInterceptor {
// 握手之前触发 (return true 才会握手成功 )
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler,
Map<String, Object> attr) {

System.out.println("---- 握手之前触发 " + StpUtil.getTokenValue());

// 未登录情况下拒绝握手
if(!StpUtil.isLogin()) {
System.out.println("---- 未授权客户端,连接失败");
return false;
}

// 标记 userId,握手成功
attr.put("userId", StpUtil.getLoginIdAsLong());

return true;
}

// 握手之后触发
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
Exception exception) {
System.out.println("---- 握手之后触发 ");
}

}

连接处理器

这里是我们的主要处理器,基本上所有重要业务都在这里

首先创建一个自己的ZVerifyWebSocketHandler然后再去继承TextWebSocketHandler我们可以定制的去实现里边的方法,这里我就按照我自己的博客需求进行重写了,如果需要可以自行扩展。

image-20220908205224501

重要属性

image-20220908205537147

这个是用来存放我们当前在线的人的信息的,用于广播和人数统计还有私信

进入聊天成功的逻辑

首先重写afterConnectionEstablished()方法这个方法是在连接开启的时候触发的,也就是我握手成功之后,因为是聊天室所以功能防QQ做了,在登录之后会看到当前博客群聊中的在线人数,然后加载聊天记录。这一些简单的过程

image-20220908211059504

  1. 首先要从session中取到当前连接的用户id,这里我要解释一下这个userId是从哪来的,是在我的握手之前触发的那个beforeHandshake()中写的项目中用的安全框架为Sa-Token,不熟悉的请自行查阅,拿到用户id之后将当前用户的webSocketSession存放到map中

  2. 更新当前的在线人数,这个处理是比较简单的image-20220908212420367

    就是获取一下map的大小就是当前在线人数,然后发送广播消息,这里说一下广播消息其实很简单就是将map中的webSocketSession都取出来然后挨个发送消息注意这里要加一个锁因为不加锁的话可能会导致消息前后异常

  3. 加载历史记录也很平常就是将我们聊天记录存到数据库中,然后将其xxx小时的消息加载出来,然后想当前登录用户发送这里我使用的是历史12小时image-20220908213511483

收到消息之后处理逻辑

处理收到消息逻辑是handleTextMessage()方法里边有两个参数一个是发送消息的session,一个是包装的消息对象TextMessage,首先先带大家看一下TextMessage是个什么东西,我们在通过webSocketSession发送消息的时候可以发送多种对象image-20220908214150426

这里我使用了TextMessage,所以就讲一下这里我们在创建TextMessage对象的时候传入参数通过源码可以知道我可以传入一个可读的char值序列然后会将其转换成字符串调用抽象类的构造方法image-20220908214411389

第二个参数的意义是这是否是作为一系列部分消息发送的消息的最后一部分。到这里可以知道我们发送的消息就是抽象类AbstractWebSocketMessage中的payload属性,所以在这里我买可以通过这个入参拿到数据,然后根据其数据的第一个参数,也就是当前的类型去进行对应的逻辑处理,这里就没什么难点了

连接关闭

image-20220908214928823

连接关闭的时候讲当前的用户session从map中remove掉就好如需扩展请自己进行逻辑的修改

源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
package com.zang.blogz.handler;

import cn.hutool.core.date.DateUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.fastjson.JSON;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.zang.blogz.dto.ChatRecordDTO;
import com.zang.blogz.dto.RecallMessageDTO;
import com.zang.blogz.dto.WebsocketMessageDTO;
import com.zang.blogz.enmus.ChatTypeEnum;
import com.zang.blogz.enmus.FilePathEnum;
import com.zang.blogz.entity.ChatRecord;
import com.zang.blogz.entity.UserInfo;
import com.zang.blogz.model.input.ro.VoiceRO;
import com.zang.blogz.service.ChatRecordService;
import com.zang.blogz.service.UserInfoService;
import com.zang.blogz.steam.optional.Opp;
import com.zang.blogz.strategy.context.UploadStrategyContext;
import com.zang.blogz.utils.BeanCopyUtils;
import com.zang.blogz.utils.IpUtil;
import lombok.Data;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import javax.websocket.server.ServerEndpoint;

import java.io.IOException;
import java.net.InetAddress;
import java.util.Collection;
import java.util.Date;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;


/**
* @Author: ZVerify
* @Description: websocket服务
* @DateTime: 2022/9/6 14:03
**/
@Data
@Service
@ServerEndpoint(value = "/ws-connect")
public class ZVerifyWebSocketHandler extends TextWebSocketHandler {


private static ChatRecordService chatRecordService;

@Autowired
public void setChatRecordDao(ChatRecordService chatRecordService) {
ZVerifyWebSocketHandler.chatRecordService = chatRecordService;
}

private static UserInfoService userInfoService;

@Autowired
public void setUserInfoService(UserInfoService userInfoService) {
ZVerifyWebSocketHandler.userInfoService = userInfoService;
}

private static UploadStrategyContext uploadStrategyContext;

@Autowired
public void setUploadStrategyContext(UploadStrategyContext uploadStrategyContext) {
ZVerifyWebSocketHandler.uploadStrategyContext = uploadStrategyContext;
}
/**
* 固定前缀
*/
public static String HEADER_NAME = "X-Real-IP";

/**
* 存放Session集合,方便推送消息
*/
private static ConcurrentHashMap<String, WebSocketSession> webSocketSessionMaps = new ConcurrentHashMap<>();



// 监听:连接开启
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {

// put到集合,方便后续操作
String userId = session.getAttributes().get("userId").toString();
webSocketSessionMaps.put(HEADER_NAME + userId, session);
// 更新在线人数
updateOnlineCount();

// 加载历史聊天记录
ChatRecordDTO chatRecordDTO = listChartRecords(session);

// 发送消息
WebsocketMessageDTO messageDTO = WebsocketMessageDTO.builder()
.type(ChatTypeEnum.HISTORY_RECORD.getType())
.data(chatRecordDTO)
.build();
synchronized (session) {
session.sendMessage(new TextMessage(JSON.toJSONString(messageDTO)));
}
// 给个提示
String tips = "Web-Socket 连接成功,sid=" + session.getId() + ",userId=" + userId;
System.out.println(tips);

}

/**
* 加载历史聊天记录
*
* @param session session
* @return 加载历史聊天记录
*/
private ChatRecordDTO listChartRecords(WebSocketSession session) {

String ipAddress = session.getAcceptedProtocol();

LambdaQueryWrapper<ChatRecord> queryWrapper = new LambdaQueryWrapper<>();

queryWrapper.ge(ChatRecord::getCreateTime, DateUtil.offsetHour(new Date(), -12));

return ChatRecordDTO.builder()
.chatRecordList(chatRecordService.list(queryWrapper))
.ipAddress(ipAddress)
.ipSource(IpUtil.getIpSource(ipAddress))
.build();
}

private void updateOnlineCount() throws IOException {

// 获取当前在线人数
WebsocketMessageDTO messageDTO = WebsocketMessageDTO.builder()
.type(ChatTypeEnum.ONLINE_COUNT.getType())
.data(webSocketSessionMaps.size())
.build();
// 广播消息
broadcastMessage(messageDTO);
}

// 监听:连接关闭
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status){
// 从集合移除
String userId = session.getAttributes().get("userId").toString();
webSocketSessionMaps.remove(HEADER_NAME + userId);

}

// 收到消息
@Override
public void handleTextMessage(WebSocketSession session, TextMessage message) throws IOException {

String ipAddress = null;
WebsocketMessageDTO messageDTO = JSONUtil.toBean(message.getPayload(), WebsocketMessageDTO.class, false);
switch (Objects.requireNonNull(ChatTypeEnum.getChatType(messageDTO.getType()))) {
case SEND_MESSAGE:

String data = String.valueOf(messageDTO.getData()) ;
InetAddress address = Objects.requireNonNull(session.getLocalAddress()).getAddress();
if (Opp.of(address).isNonNull()){

ipAddress = address.getHostAddress();
}


String userId = session.getAttributes().get("userId").toString();
UserInfo byId = userInfoService.getById(Integer.valueOf(userId));

// 发送消息
ChatRecord chatRecord = new ChatRecord();

chatRecord.setContent(data);
chatRecord.setType(messageDTO.getType());
chatRecord.setAvatar(byId.getAvatar());
chatRecord.setNickname(byId.getNickname());
chatRecord.setUserId(byId.getId());
chatRecord.setIpAddress(ipAddress);
String ipSource = IpUtil.getIpSource(ipAddress);
chatRecord.setIpSource(ipSource);
chatRecordService.save(chatRecord);

messageDTO.setData(chatRecord);
// 广播消息
broadcastMessage(messageDTO);
break;
case RECALL_MESSAGE:
// 撤回消息
RecallMessageDTO recallMessage = JSON.parseObject(JSON.toJSONString(messageDTO.getData()), RecallMessageDTO.class);
// 删除记录
chatRecordService.removeById(recallMessage.getId());
// 广播消息
broadcastMessage(messageDTO);
break;
case HEART_BEAT:
// 心跳消息
messageDTO.setData("pong");
session.sendMessage(new TextMessage((JSON.toJSONString(messageDTO))));

default:
break;
}
}

// -----------

// 向指定客户端推送消息
public static void sendMessage(WebSocketSession session, String message) {
try {
System.out.println("向sid为:" + session.getId() + ",发送:" + message);
session.sendMessage(new TextMessage(message));
} catch (IOException e) {
throw new RuntimeException(e);
}
}

// 向指定用户推送消息
public static void sendMessage(long userId, String message) {
WebSocketSession session = webSocketSessionMaps.get(HEADER_NAME + userId);
if(session != null) {
sendMessage(session, message);
}
}

/**
* 广播消息
*
* @param messageDTO 消息dto
* @throws IOException io异常
*/
private void broadcastMessage(WebsocketMessageDTO messageDTO) throws IOException {

Collection<WebSocketSession> sessions = webSocketSessionMaps.values();

for (WebSocketSession webSocketService : sessions) {
synchronized (webSocketService){
TextMessage textMessage = new TextMessage(JSON.toJSONString(messageDTO));
webSocketService.sendMessage(textMessage);
}

}
}

/**
* 发送语音
*
* @param voiceRO 语音路径
*/
public void sendVoice(VoiceRO voiceRO) {
// 上传语音文件
String content = uploadStrategyContext.executeUploadStrategy(voiceRO.getFile(), FilePathEnum.VOICE.getPath());
voiceRO.setContent(content);
// 保存记录
ChatRecord chatRecord = BeanCopyUtils.copyObject(voiceRO, ChatRecord.class);
chatRecordService.save(chatRecord);
// 发送消息
WebsocketMessageDTO messageDTO = WebsocketMessageDTO.builder()
.type(ChatTypeEnum.VOICE_MESSAGE.getType())
.data(chatRecord)
.build();
// 广播消息
try {
broadcastMessage(messageDTO);
} catch (IOException e) {
e.printStackTrace();
}
}



}