diff --git a/src/main/java/com/sztzjy/marketing/config/security/WebConfigurerAdapter.java b/src/main/java/com/sztzjy/marketing/config/security/WebConfigurerAdapter.java index ea91b80..4bb41f8 100644 --- a/src/main/java/com/sztzjy/marketing/config/security/WebConfigurerAdapter.java +++ b/src/main/java/com/sztzjy/marketing/config/security/WebConfigurerAdapter.java @@ -6,11 +6,15 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.support.ReloadableResourceBundleMessageSource; +import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.core.task.TaskExecutor; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.web.cors.CorsConfiguration; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import org.springframework.web.filter.CorsFilter; +import org.springframework.web.servlet.config.annotation.AsyncSupportConfigurer; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.config.annotation.ResourceHandlerRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; @@ -33,6 +37,24 @@ public class WebConfigurerAdapter implements WebMvcConfigurer { @Value("${file.path}") private String filePath; + + @Bean + public AsyncTaskExecutor asyncTaskExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setCorePoolSize(10); // 设置核心线程数 + executor.setMaxPoolSize(50); // 设置最大线程数 + executor.setQueueCapacity(100); // 设置队列容量 + executor.setThreadNamePrefix("async-executor-"); + executor.initialize(); + return executor; + } + + @Override + public void configureAsyncSupport(AsyncSupportConfigurer configurer) { + configurer.setTaskExecutor(asyncTaskExecutor()); + configurer.setDefaultTimeout(30000); // 设置默认超时时间 + } + @Bean public IFileUtil getFileUtil() { if (fileType.equals("local")) { diff --git a/src/main/java/com/sztzjy/marketing/controller/stu/QianFanBigModuleController.java b/src/main/java/com/sztzjy/marketing/controller/stu/QianFanBigModuleController.java index 041b53c..70efdf2 100644 --- a/src/main/java/com/sztzjy/marketing/controller/stu/QianFanBigModuleController.java +++ b/src/main/java/com/sztzjy/marketing/controller/stu/QianFanBigModuleController.java @@ -4,6 +4,7 @@ import com.sztzjy.marketing.annotation.AnonymousAccess; import com.sztzjy.marketing.entity.dto.StuCreateArticleDTO; import com.sztzjy.marketing.entity.dto.StuCreateImgDTO; import com.sztzjy.marketing.qianfan.Qianfan; +import com.sztzjy.marketing.qianfan.model.chat.Message; import com.sztzjy.marketing.service.QianFanBigModuleService; import com.sztzjy.marketing.util.ResultDataEntity; import com.sztzjy.marketing.util.ResultEntity; @@ -15,6 +16,11 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpStatus; import org.springframework.web.bind.annotation.*; +import org.yaml.snakeyaml.reader.StreamReader; +import reactor.core.publisher.Flux; + +import javax.annotation.Resource; +import java.util.List; /** * @author 17803 @@ -43,18 +49,28 @@ public class QianFanBigModuleController { } +// @ApiOperation("AI文章") +// @AnonymousAccess +// @PostMapping("/createArticleByAi") +// public ResultEntity createArticleByAi(@RequestBody StuCreateArticleDTO stuCreateArticleDTO) { +// +// +// return qianFanBigModuleService.createArticleByAi(stuCreateArticleDTO); +// +// } + + @ApiOperation("AI文章") @AnonymousAccess @PostMapping("/createArticleByAi") - public ResultEntity createArticleByAi(@RequestBody StuCreateArticleDTO stuCreateArticleDTO) { + public Flux createArticleByAi(@RequestBody List messageList) { - return qianFanBigModuleService.createArticleByAi(stuCreateArticleDTO); - - } + return qianFanBigModuleService.createArticleByMessage(messageList); + } } diff --git a/src/main/java/com/sztzjy/marketing/service/QianFanBigModuleService.java b/src/main/java/com/sztzjy/marketing/service/QianFanBigModuleService.java index 439e7b6..f6aa2ca 100644 --- a/src/main/java/com/sztzjy/marketing/service/QianFanBigModuleService.java +++ b/src/main/java/com/sztzjy/marketing/service/QianFanBigModuleService.java @@ -2,7 +2,11 @@ package com.sztzjy.marketing.service; import com.sztzjy.marketing.entity.dto.StuCreateArticleDTO; import com.sztzjy.marketing.entity.dto.StuCreateImgDTO; +import com.sztzjy.marketing.qianfan.model.chat.Message; import com.sztzjy.marketing.util.ResultEntity; +import reactor.core.publisher.Flux; + +import java.util.List; /** * @author 17803 @@ -17,11 +21,20 @@ public interface QianFanBigModuleService { */ ResultEntity createImgByAi(StuCreateImgDTO stuCreateImgDTO); +// /** +// * AI文章 +// * @param stuCreateArticleDTO +// * @return +// */ +// +// ResultEntity createArticleByAi(StuCreateArticleDTO stuCreateArticleDTO); + + /** * AI文章 - * @param stuCreateArticleDTO + * @param messageList * @return */ - ResultEntity createArticleByAi(StuCreateArticleDTO stuCreateArticleDTO); + Flux createArticleByMessage(List messageList); } diff --git a/src/main/java/com/sztzjy/marketing/service/impl/QianFanBigModuleServiceImpl.java b/src/main/java/com/sztzjy/marketing/service/impl/QianFanBigModuleServiceImpl.java index 25c7fac..f3e7122 100644 --- a/src/main/java/com/sztzjy/marketing/service/impl/QianFanBigModuleServiceImpl.java +++ b/src/main/java/com/sztzjy/marketing/service/impl/QianFanBigModuleServiceImpl.java @@ -3,6 +3,7 @@ package com.sztzjy.marketing.service.impl;/** * @date 2024-06-03 15:26 */ +import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.sztzjy.marketing.entity.dto.ReqChatMessage; import com.sztzjy.marketing.entity.dto.StuCreateArticleDTO; @@ -19,11 +20,16 @@ import lombok.extern.slf4j.Slf4j; import okhttp3.*; import okio.BufferedSource; import okio.Okio; +import org.apache.commons.lang3.StringUtils; import org.apache.http.client.methods.HttpPost; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.task.AsyncTaskExecutor; import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; +import java.io.BufferedReader; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; @@ -40,7 +46,8 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { @Value("${bigModule.secretKey}") public String secretKey; - + @Autowired + private AsyncTaskExecutor asyncTaskExecutor; @Override public ResultEntity createImgByAi(StuCreateImgDTO stuCreateImgDTO) { @@ -52,10 +59,7 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { throw new RuntimeException(e); } - chat(accesstoken); - - - +// chat(accesstoken); Qianfan qianfan = new Qianfan("OAuth", accessKey, secretKey); @@ -69,8 +73,39 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { } //AI文章 +// @Override +// public ResultEntity createArticleByAi(StuCreateArticleDTO stuCreateArticleDTO) { +// +// +// +// String accesstoken = null; +// try { +// accesstoken = getAccesstoken(); +// } catch (IOException e) { +// throw new RuntimeException(e); +// } +// +//// chat(accesstoken); +// +// ChatResponse response = new Qianfan("OAuth",accessKey, secretKey).chatCompletion() +// .model("ERNIE-3.5-8K") // 使用model指定预置模型 +// //.endpoint("ERNIE-Bot") // 也可以使用endpoint指定任意模型 (二选一) +// .addMessage("user", stuCreateArticleDTO.getContent()) // 添加用户消息 (此方法可以调用多次,以实现多轮对话的消息传递) +// .temperature(0.7) // 自定义超参数 +// .execute(); // 发起请求 +// +// return new ResultEntity<>(response.getResult()); +// +// } + + /** + * AI文章 + * @param messageList + * @return + */ + @Override - public ResultEntity createArticleByAi(StuCreateArticleDTO stuCreateArticleDTO) { + public Flux createArticleByMessage(List messageList) { @@ -81,17 +116,9 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { throw new RuntimeException(e); } - chat(accesstoken); - - ChatResponse response = new Qianfan("OAuth",accessKey, secretKey).chatCompletion() - .model("ERNIE-3.5-8K") // 使用model指定预置模型 - //.endpoint("ERNIE-Bot") // 也可以使用endpoint指定任意模型 (二选一) - .addMessage("user", stuCreateArticleDTO.getContent()) // 添加用户消息 (此方法可以调用多次,以实现多轮对话的消息传递) - .temperature(0.7) // 自定义超参数 - .execute(); // 发起请求 - return new ResultEntity<>(response.getResult()); + return (chat(accesstoken,messageList)); } //获取accessToken @@ -120,30 +147,29 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { - public void chat(String accesstoken) { + public Flux chat(String accesstoken,List messageList) { try { //获取RequestBody - RequestBody funcCallReqBody = getRequestBody(); + RequestBody funcCallReqBody = getRequestBody(messageList); //发送请求 - sendRequest(funcCallReqBody,accesstoken); + return sendRequest(funcCallReqBody,accesstoken); } catch (Exception e) { log.error("chat执行失败", e); } + return null; } /** * 获取RequestBody * @return */ - protected static RequestBody getRequestBody() { - List messages = new ArrayList<>(); - // Message chatMsg = new Message("user", "content"); - Message chatMsg = new Message(); - chatMsg.setRole("user"); - chatMsg.setContent("您好"); - messages.add(chatMsg); + protected static RequestBody getRequestBody(List messages) { +// List messages = new ArrayList<>(); +// // Message chatMsg = new Message("user", "content"); +// +// messages.add(message); //获取组装好的请求 ReqChatMessage reqMessage = new ReqChatMessage(); @@ -164,29 +190,75 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { * @throws IOException * @returnline */ - protected static String sendRequest(RequestBody body, String accessToken) { - String respResult = ""; + protected Flux sendRequest(RequestBody body, String accessToken) throws IOException { + + Request request = new Request.Builder().url("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0329?access_token=" + accessToken).method("POST", body).addHeader("Content-Type", "application/json").build(); + OkHttpClient client = new OkHttpClient(); + + return Flux.create(sink -> { + asyncTaskExecutor.execute(() -> { // 使用 AsyncTaskExecutor 执行异步任务 + Response response = null; + ResponseBody responseBody = null; + BufferedSource reader = null; + try { + response = client.newCall(request).execute(); + responseBody = response.body(); + if (responseBody != null) { + reader = Okio.buffer(responseBody.source()); + String line; + while ((line = reader.readUtf8Line()) != null) { + line = line.replace("data: ", ""); + JSONObject json = JSON.parseObject(line); + if (json != null && StringUtils.isNotBlank((String) json.get("result"))) { + sink.next((String) json.get("result")); // 返回答案 + System.out.println("Received data: " + json.get("result")); // 打印接收到的数据 + if (Boolean.TRUE.equals(json.get("is_end"))) { + sink.complete(); // 结束响应 + break; + } + } else if (json != null && Boolean.TRUE.equals(json.get("is_end"))) { + sink.complete(); // 结束响应 + break; + } else { + sink.next("\u200B"); // 返回空白字符 + } + } + sink.complete(); // 完成响应 + } else { + sink.complete(); // 无响应体,结束响应 + } + } catch (IOException e) { + sink.error(e); + } finally { + try { + if (reader != null) reader.close(); + if (responseBody != null) responseBody.close(); + } catch (IOException e) { + // handle exception + } + } + }); + }); + +//=========================== - OkHttpClient client = new OkHttpClient(); - - try (Response response = client.newCall(request).execute()) { - if (!response.isSuccessful()) { - throw new IOException("Unexpected code " + response); - } - BufferedSource source = Okio.buffer(response.body().source()); - String line = ""; - while ((line = source.readUtf8Line()) != null) { - //流式打印出来 - System.out.println(line); - } - } catch (Exception e) { - log.error("sendRequest请求执行失败", e); - } - - return respResult; - } +// +// try (Response response = client.newCall(request).execute()) { +// if (!response.isSuccessful()) { +// throw new IOException("Unexpected code " + response); +// } +// BufferedSource source = Okio.buffer(response.body().source()); +// String line = ""; +// while ((line = source.readUtf8Line()) != null) { +// //流式打印出来 +// System.out.println(line); +// return line; +// } +// } catch (Exception e) { +// log.error("sendRequest请求执行失败", e); +// } // public static void main(String[] args) throws IOException { @@ -195,6 +267,7 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { // // // String secretKey = "OrY5ZaSdv3mKtqzfh83x5DShCfIF4gi1"; +// // String requestBody = "grant_type=" + URLEncoder.encode("client_credentials", "UTF8") + // "&client_id=" + URLEncoder.encode(accessKey, "UTF8")+ // "&client_secret=" + URLEncoder.encode(secretKey, "UTF8"); @@ -215,11 +288,14 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { // System.out.println(accessToken); // // +// Message message = new Message(); +// message.setRole("user"); +// message.setContent("今天天气怎么样"); // // // -// //获取RequestBody -// RequestBody funcCallReqBody = getRequestBody(); +// //获取RequestBody +// RequestBody funcCallReqBody = getRequestBody(message); // // //发送请求 // sendRequest(funcCallReqBody,accessToken); @@ -228,4 +304,67 @@ public class QianFanBigModuleServiceImpl implements QianFanBigModuleService { // // } + //================== + + // String respResult = ""; +// Request request = new Request.Builder().url("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0329?access_token=" + +// accessToken).method("POST", body).addHeader("Content-Type", "application/json").build(); +// +// +// +// +// OkHttpClient client = new OkHttpClient(); +// +// +// Response response = client.newCall(request).execute(); +// ResponseBody responseBody = response.body(); +// if (responseBody != null) { +// BufferedSource reader = Okio.buffer(response.body().source()); +// return Flux.generate(sink -> { // 流式响应式接口 +// try { +// String line = reader.readUtf8Line(); +// if (null != line) { +// line = line.replace("data: ", ""); +// JSONObject json = JSON.parseObject(line); +// if (null != json && StringUtils.isNotBlank((String) json.get("result"))) { +// sink.next((String) json.get("result")); // 返回答案 +// if ((Boolean) json.get("is_end")) { +// +// sink.complete(); // 结束响应 +// } +// } else if (null != json && (Boolean) json.get("is_end")) { +// +// sink.complete(); // 结束响应 +// } else { +// sink.next("\u200B"); // 返回空白字符 +// } +// } +// } catch (IOException e) { +// sink.error(e); +// } +// }).doFinally(signalType -> { +// try { +// reader.close(); +// responseBody.close(); +// } catch (IOException e) { +// // handle exception +// } +// }); +// } +// +// +// +// +// +// +// +// +// +// +// return Flux.empty(); +//==== + + } + + } diff --git a/src/main/java/com/sztzjy/marketing/util/ResultDataEntity.java b/src/main/java/com/sztzjy/marketing/util/ResultDataEntity.java index 1cc8f0c..04e4b5b 100644 --- a/src/main/java/com/sztzjy/marketing/util/ResultDataEntity.java +++ b/src/main/java/com/sztzjy/marketing/util/ResultDataEntity.java @@ -3,6 +3,7 @@ package com.sztzjy.marketing.util; import io.swagger.annotations.ApiModel; import lombok.Getter; import org.springframework.http.HttpStatus; +import reactor.core.publisher.Flux; /** * 出参对象统一封装 @@ -24,6 +25,9 @@ public class ResultDataEntity { // @JsonInclude(JsonInclude.Include.NON_NULL) private T data = null; + + private Flux flow; + public ResultDataEntity(T data) { this.code = HttpStatus.OK.value(); this.data = data; @@ -34,6 +38,11 @@ public class ResultDataEntity { this.data = data; } + public ResultDataEntity(HttpStatus status, Flux data) { + this.code = status.value(); + this.flow = data; + } + public ResultDataEntity(HttpStatus status, String msg, T data) { this.code = status.value(); this.msg = msg; diff --git a/src/main/java/com/sztzjy/marketing/util/ResultEntity.java b/src/main/java/com/sztzjy/marketing/util/ResultEntity.java index 0c466d9..49ac07a 100644 --- a/src/main/java/com/sztzjy/marketing/util/ResultEntity.java +++ b/src/main/java/com/sztzjy/marketing/util/ResultEntity.java @@ -4,6 +4,7 @@ import io.swagger.annotations.ApiModel; import lombok.Getter; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; +import reactor.core.publisher.Flux; /** * 对ResponseEntity做了简单包装 @@ -35,4 +36,8 @@ public class ResultEntity extends ResponseEntity> { public ResultEntity(HttpStatus status) { super(new ResultDataEntity<>(status), status); } + + public ResultEntity(HttpStatus status, Flux flow) { + super(new ResultDataEntity<>(status, flow), status); + } }