package com.ecep.contract.service; import com.ecep.contract.Message; import com.ecep.contract.constant.WebSocketConstant; import com.ecep.contract.ui.Tasker; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.util.Locale; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @Service public class WebSocketServerTaskManager implements InitializingBean { private static final Logger logger = LoggerFactory.getLogger(WebSocketServerTaskManager.class); @Autowired private ObjectMapper objectMapper; @Autowired private ScheduledExecutorService scheduledExecutorService; private Map taskClzMap = Map.of(); @Override public void afterPropertiesSet() throws Exception { taskClzMap = Map.of( "ContractSyncTask", "com.ecep.contract.cloud.u8.ContractSyncTask", "ContractRepairTask", "com.ecep.contract.ds.contract.tasker.ContractRepairTask", "ContractVerifyTask", "com.ecep.contract.ds.contract.tasker.ContractVerifyTask", "ProjectCostImportItemsFromContractsTasker", "com.ecep.contract.ds.project.ProjectCostImportItemsFromContractsTasker" ); } public void onMessage(WebSocketSession session, JsonNode jsonNode) { // 处理 sessionId 的消息 String sessionId = jsonNode.get(WebSocketConstant.SESSION_ID_FIELD_NAME).asText(); try { handleAsSessionCallback(session, sessionId, jsonNode); } catch (Exception e) { sendError(session, sessionId, e.getMessage()); logger.warn("处理会话回调失败 (会话ID: {}): {}", sessionId, e.getMessage(), e); } } private void handleAsSessionCallback(WebSocketSession session, String sessionId, JsonNode jsonNode) { if (!jsonNode.has("type")) { throw new IllegalArgumentException("缺失 type 参数"); } String type = jsonNode.get("type").asText(); if (type.equals("createTask")) { createTask(session, sessionId, jsonNode); } } private void createTask(WebSocketSession session, String sessionId, JsonNode jsonNode) { if (!jsonNode.has(WebSocketConstant.ARGUMENTS_FIELD_NAME)) { throw new IllegalArgumentException("缺失 " + WebSocketConstant.ARGUMENTS_FIELD_NAME + " 参数"); } JsonNode argsNode = jsonNode.get(WebSocketConstant.ARGUMENTS_FIELD_NAME); String taskName = argsNode.get(0).asText(); String clzName = taskClzMap.get(taskName); if (clzName == null) { throw new IllegalArgumentException("未知的任务类型: " + taskName); } Object tasker = null; try { Class clz = Class.forName(clzName); tasker = clz.getDeclaredConstructor().newInstance(); } catch (ClassNotFoundException e) { throw new IllegalArgumentException("未知的任务类型: " + taskName + ", class: " + clzName); } catch (Exception e) { throw new IllegalArgumentException("任务类型: " + taskName + ", class: " + clzName + " 实例化失败"); } if (tasker instanceof Tasker t) { String locale = argsNode.get(1).asText(); t.setLocale(Locale.forLanguageTag(locale)); } if (tasker instanceof WebSocketServerTasker t) { t.setTitleHandler(title -> sendToSession(session, sessionId, "title", title)); t.setMessageHandler(msg -> sendMessageToSession(session, sessionId, msg)); t.setPropertyHandler((name, value) -> sendToSession(session, sessionId, "property", name, value)); t.setProgressHandler((current, total) -> sendToSession(session, sessionId, "progress", current, total)); t.init(argsNode.get(2)); } if (tasker instanceof Callable callable) { Thread.ofVirtual().start(() -> { try { sendToSession(session, sessionId, "start"); callable.call(); sendToSession(session, sessionId, "done"); } catch (Exception e) { throw new RuntimeException(e); } }); } } private boolean sendMessageToSession(WebSocketSession session, String sessionId, Message msg) { return sendToSession(session, sessionId, "message", msg.getLevel().getName(), msg.getMessage()); } private boolean sendToSession(WebSocketSession session, String sessionId, String type, Object... args) { try { String text = objectMapper.writeValueAsString(Map.of( WebSocketConstant.SESSION_ID_FIELD_NAME, sessionId, "type", type, WebSocketConstant.ARGUMENTS_FIELD_NAME, args )); session.sendMessage(new TextMessage(text)); } catch (IOException e) { // 捕获所有可能的异常,防止影响主流程 logger.error("发送错误消息失败 (会话ID: {})", session.getId(), e); } return true; } private void sendError(WebSocketSession session, String sessionId, String message) { if (session == null || !session.isOpen()) { logger.warn("尝试向已关闭的WebSocket会话发送错误消息: {}", message); return; } try { String errorMessage = objectMapper.writeValueAsString(Map.of( WebSocketConstant.SESSION_ID_FIELD_NAME, sessionId, WebSocketConstant.SUCCESS_FIELD_VALUE, false, WebSocketConstant.MESSAGE_FIELD_NAME, message )); // 检查会话状态并尝试发送错误消息 if (session.isOpen()) { session.sendMessage(new TextMessage(errorMessage)); } else { logger.warn("会话已关闭,无法发送错误消息: {}", message); } } catch (Exception e) { // 捕获所有可能的异常,防止影响主流程 logger.error("发送错误消息失败 (会话ID: {})", session.getId(), e); } } }