diff --git a/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java b/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java index e85bdba..ef5fc47 100644 --- a/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java +++ b/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java @@ -204,25 +204,25 @@ public class StuDigitalMarketingModelController { } - @ApiOperation("回归分析--线性回归") - @PostMapping("/regressionAnalysis") - @AnonymousAccess - public ResultEntity regressionAnalysis(@RequestBody RegressionAnalysisDTO analysisDTO) { - double[][] x = analysisDTO.getX(); - double[] y = analysisDTO.getY(); - - // 创建线性回归模型 - LinearRegression regression=new LinearRegression(); - regression.fit(x,y); - - - RAnalysisDTO rAnalysisDTO=new RAnalysisDTO(); - rAnalysisDTO.setIntercept(regression.getIntercept()); - rAnalysisDTO.setSlopes(regression.getSlopes()); - rAnalysisDTO.setType("线性回归"); - - return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO); - } +// @ApiOperation("回归分析--线性回归") +// @PostMapping("/regressionAnalysis") +// @AnonymousAccess +// public ResultEntity regressionAnalysis(@RequestBody RegressionAnalysisDTO analysisDTO) { +// double[][] x = analysisDTO.getX(); +// double[] y = analysisDTO.getY(); +// +// // 创建线性回归模型 +// LinearRegression regression=new LinearRegression(); +// regression.fit(x,y); +// +// +// RAnalysisDTO rAnalysisDTO=new RAnalysisDTO(); +// rAnalysisDTO.setIntercept(regression.getIntercept()); +// rAnalysisDTO.setSlopes(regression.getSlopes()); +// rAnalysisDTO.setType("线性回归"); +// +// return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO); +// } @ApiOperation("回归分析--逻辑回归") @@ -241,7 +241,6 @@ public class StuDigitalMarketingModelController { rAnalysisDTO.setIntercept(Double.parseDouble(df.format(model.getIntercept()))); rAnalysisDTO.setSlope(Double.parseDouble(df.format(model.getCoefficient()))); - rAnalysisDTO.setType("逻辑回归"); return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO); } @@ -253,18 +252,12 @@ public class StuDigitalMarketingModelController { public ResultEntity prediction(@RequestBody LogisticPredictDTO logisticPredictDTO) { DecimalFormat df = new DecimalFormat("#.0"); - if(logisticPredictDTO.getType().equals(Constant.LINEAR_REGRESSION)){ - // 创建线性回归模型 - LinearRegression regression=new LinearRegression(); - double predict = regression.predict(logisticPredictDTO.getRaX(), logisticPredictDTO.getIntercept(), logisticPredictDTO.getSlopes()); - return new ResultEntity(HttpStatus.OK,"成功",df.format(predict)); - }else { // 创建一个逻辑回归模型实例,0.1学习率 LogisticRegression model = new LogisticRegression(0.1); double predict = model.predict(logisticPredictDTO.getLX(), logisticPredictDTO.getSlope(), logisticPredictDTO.getIntercept());//预测 return new ResultEntity(HttpStatus.OK,"成功",df.format(predict)); - } + } diff --git a/src/main/java/com/sztzjy/marketing/controller/stu/StuOnAnalysisController.java b/src/main/java/com/sztzjy/marketing/controller/stu/StuOnAnalysisController.java index 3313b16..c2dc027 100644 --- a/src/main/java/com/sztzjy/marketing/controller/stu/StuOnAnalysisController.java +++ b/src/main/java/com/sztzjy/marketing/controller/stu/StuOnAnalysisController.java @@ -184,46 +184,46 @@ public class StuOnAnalysisController { } - @ApiOperation("回归分析表格下载") - @PostMapping("/rAMExport") - @AnonymousAccess - public void rAMExport(HttpServletResponse response, @RequestBody RAnalysisExportDTO rAnalysisExportDTO) { - //导出的表名 - String title = IdUtil.simpleUUID(); - - List list=new ArrayList<>(); - List listColumn=new ArrayList<>(); - list.add("回归系数值"); - - - if(rAnalysisExportDTO.getType().equals(Constant.LINEAR_REGRESSION)){ - - list.add("线性回归常数项值"); - //具体需要写入excel需要哪些字段,这些字段取自UserReward类,也就是上面的实际数据结果集的泛型 - listColumn = Arrays.asList("intercept", "slopes", "type","prediction"); - - - }else { - list.add("逻辑回归常数项值"); - //具体需要写入excel需要哪些字段,这些字段取自UserReward类,也就是上面的实际数据结果集的泛型 - listColumn = Arrays.asList("intercept", "slope", "type","prediction"); - } - list.add("回归类型"); - list.add("预测值"); - - //表中第一行表头字段 - String[] headers=list.toArray(new String[list.size()]); // 使用 list.toArray() 转换为 String数组 - - - //实际数据结果集 - List rAnalysisExportDTOS=new ArrayList<>(); - rAnalysisExportDTOS.add(rAnalysisExportDTO); - - - try { - FilePortUtil.exportExcel(response, title, headers, rAnalysisExportDTOS, listColumn); - } catch (Exception e) { - e.printStackTrace(); - } - } +// @ApiOperation("回归分析表格下载") +// @PostMapping("/rAMExport") +// @AnonymousAccess +// public void rAMExport(HttpServletResponse response, @RequestBody RAnalysisExportDTO rAnalysisExportDTO) { +// //导出的表名 +// String title = IdUtil.simpleUUID(); +// +// List list=new ArrayList<>(); +// List listColumn=new ArrayList<>(); +// list.add("回归系数值"); +// +// +// if(rAnalysisExportDTO.getType().equals(Constant.LINEAR_REGRESSION)){ +// +// list.add("线性回归常数项值"); +// //具体需要写入excel需要哪些字段,这些字段取自UserReward类,也就是上面的实际数据结果集的泛型 +// listColumn = Arrays.asList("intercept", "slopes", "type","prediction"); +// +// +// }else { +// list.add("逻辑回归常数项值"); +// //具体需要写入excel需要哪些字段,这些字段取自UserReward类,也就是上面的实际数据结果集的泛型 +// listColumn = Arrays.asList("intercept", "slope", "type","prediction"); +// } +// list.add("回归类型"); +// list.add("预测值"); +// +// //表中第一行表头字段 +// String[] headers=list.toArray(new String[list.size()]); // 使用 list.toArray() 转换为 String数组 +// +// +// //实际数据结果集 +// List rAnalysisExportDTOS=new ArrayList<>(); +// rAnalysisExportDTOS.add(rAnalysisExportDTO); +// +// +// try { +// FilePortUtil.exportExcel(response, title, headers, rAnalysisExportDTOS, listColumn); +// } catch (Exception e) { +// e.printStackTrace(); +// } +// } } diff --git a/src/main/java/com/sztzjy/marketing/controller/stu/StuPythonController.java b/src/main/java/com/sztzjy/marketing/controller/stu/StuPythonController.java index 4153e10..15093d0 100644 --- a/src/main/java/com/sztzjy/marketing/controller/stu/StuPythonController.java +++ b/src/main/java/com/sztzjy/marketing/controller/stu/StuPythonController.java @@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.sztzjy.marketing.annotation.AnonymousAccess; +import com.sztzjy.marketing.entity.dto.CommentDTO; import com.sztzjy.marketing.entity.dto.SentimentAnalyDTO; import com.sztzjy.marketing.qianfan.util.Json; import com.sztzjy.marketing.util.ResultEntity; @@ -17,8 +18,7 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import java.io.*; -import java.util.List; -import java.util.Map; +import java.util.*; @RestController @@ -141,4 +141,234 @@ public class StuPythonController { } } + + @PostMapping("/comment") + @ApiOperation("评论点抽取") + @AnonymousAccess + public ResultEntity comment(@RequestBody JSONObject text) { + // System.out.println(text); + String testText = text.getString("text"); + String code = "# 导入所需依赖\n" + + "import pandas as pd\n" + + "import paddle\n" + + "from paddlenlp.transformers import SkepTokenizer, SkepModel\n" + + "from utils.utils import decoding, concate_aspect_and_opinion, format_print\n" + + "from utils import data_ext, data_cls\n" + + "from utils.model_define import SkepForTokenClassification, SkepForSequenceClassification\n" + + "\n" + + "\n" + + "# 单条文本情感分析预测函数\n" + + "def predict(input_text, ext_model, cls_model, tokenizer, ext_id2label, cls_id2label, max_length=512):\n" + + " ext_model.eval()\n" + + " cls_model.eval()\n" + + "\n" + + " # processing input text\n" + + " encoded_inputs = tokenizer(list(input_text), is_split_into_words=True, max_length=max_length,)\n" + + " input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" + + " token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" + + "\n" + + " # extract aspect and opinion words\n" + + " logits = ext_model(input_ids, token_type_ids=token_type_ids)\n" + + " predictions = logits.argmax(axis=2).numpy()[0]\n" + + " tag_seq = [ext_id2label[idx] for idx in predictions][1:-1]\n" + + " aps = decoding(input_text, tag_seq)\n" + + "\n" + + " # predict sentiment for aspect with cls_model\n" + + " results = []\n" + + " for ap in aps:\n" + + " aspect = ap[0]\n" + + " opinion_words = list(set(ap[1:]))\n" + + " aspect_text = concate_aspect_and_opinion(input_text, aspect, opinion_words)\n" + + " \n" + + " encoded_inputs = tokenizer(aspect_text, text_pair=input_text, max_length=max_length, return_length=True)\n" + + " input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" + + " token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" + + "\n" + + " logits = cls_model(input_ids, token_type_ids=token_type_ids)\n" + + " prediction = logits.argmax(axis=1).numpy()[0]\n" + + "\n" + + " result = {\"aspect\": aspect, \"opinions\": opinion_words, \"sentiment\": cls_id2label[prediction]}\n" + + " results.append(result)\n" + + "\n" + + " # print results\n" + + " format_print(results)\n" + + "\n" + + " # 返回预测结果\n" + + " return results\n" + + "\n" + + "# 批量情感分析预测函数\n" + + "def batchPredict(data, ext_model, cls_model, tokenizer, ext_id2label, cls_id2label, max_length=512):\n" + + "\n" + + " ext_model.eval()\n" + + " cls_model.eval()\n" + + "\n" + + " analysisResults = []\n" + + "\n" + + " # 针对批量文本逐条处理\n" + + " for input_text in data:\n" + + " # processing input text\n" + + " encoded_inputs = tokenizer(list(input_text), is_split_into_words=True, max_length=max_length,)\n" + + " input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" + + " token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" + + "\n" + + " # extract aspect and opinion words\n" + + " logits = ext_model(input_ids, token_type_ids=token_type_ids)\n" + + " predictions = logits.argmax(axis=2).numpy()[0]\n" + + " tag_seq = [ext_id2label[idx] for idx in predictions][1:-1]\n" + + " aps = decoding(input_text, tag_seq)\n" + + "\n" + + " # predict sentiment for aspect with cls_model\n" + + " results = []\n" + + " for ap in aps:\n" + + " aspect = ap[0]\n" + + " opinion_words = list(set(ap[1:]))\n" + + " aspect_text = concate_aspect_and_opinion(input_text, aspect, opinion_words)\n" + + " \n" + + " encoded_inputs = tokenizer(aspect_text, text_pair=input_text, max_length=max_length, return_length=True)\n" + + " input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" + + " token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" + + "\n" + + " logits = cls_model(input_ids, token_type_ids=token_type_ids)\n" + + " prediction = logits.argmax(axis=1).numpy()[0]\n" + + "\n" + + " result = {\"属性\": aspect, \"观点\": opinion_words, \"情感倾向\": cls_id2label[prediction]}\n" + + " results.append(result)\n" + + " singleResult = {\"text\": input_text, \"result\": str(results)}\n" + + " analysisResults.append(singleResult)\n" + + "\n" + + " # 返回预测结果 list形式\n" + + " return analysisResults\n" + + "\n" + + "\n" + + "label_ext_path = \"./label_ext.dict\"\n" + + "label_cls_path = \"./label_cls.dict\"\n" + + "# 加载PaddleNLP开源的基于全量数据训练好的评论观点抽取模型和属性级情感分类模型\n" + + "ext_model_path = \"./model/best_ext.pdparams\"\n" + + "cls_model_path = \"./model/best_cls.pdparams\"\n" + + "\n" + + "# load dict\n" + + "model_name = \"skep_ernie_1.0_large_ch\"\n" + + "ext_label2id, ext_id2label = data_ext.load_dict(label_ext_path)\n" + + "cls_label2id, cls_id2label = data_cls.load_dict(label_cls_path)\n" + + "tokenizer = SkepTokenizer.from_pretrained(model_name)\n" + + "print(\"label dict loaded.\")\n" + + "\n" + + "# load ext model 观点抽取模型\n" + + "ext_state_dict = paddle.load(ext_model_path)\n" + + "ext_skep = SkepModel.from_pretrained(model_name)\n" + + "ext_model = SkepForTokenClassification(ext_skep, num_classes=len(ext_label2id))\n" + + "ext_model.load_dict(ext_state_dict)\n" + + "print(\"extraction model loaded.\")\n" + + "\n" + + "# load cls model 属性级情感分析模型\n" + + "cls_state_dict = paddle.load(cls_model_path)\n" + + "cls_skep = SkepModel.from_pretrained(model_name)\n" + + "cls_model = SkepForSequenceClassification(cls_skep, num_classes=len(cls_label2id))\n" + + "cls_model.load_dict(cls_state_dict)\n" + + "print(\"classification model loaded.\")\n" + + "\n" + + "# 单条文本情感分析\n" + + "max_length = 512\n" + + "input_text = '经济实惠、动力不错、油耗低'\n" + + "predict(input_text, ext_model, cls_model, tokenizer, ext_id2label, cls_id2label, max_length=max_length)"; + + // 替换代码中的 test_text 为实际的 testText + String updatedCode = code.replace("'经济实惠、动力不错、油耗低'", "'" + testText + "'"); + + //System.out.println(updatedCode); + try { + String s = IdUtil.simpleUUID(); + String tempPythonFile = "/usr/local/tianzeProject/digitalMarketing/comment/backend/" + s + ".py"; + + File file1 = new File(tempPythonFile); + + // 确保父目录存在 + File parentDir = file1.getParentFile(); + if (!parentDir.exists()) { + System.out.println("Parent directory does not exist. Creating it."); + if (!parentDir.mkdirs()) { + System.out.println("Failed to create directories."); + return new ResultEntity<>(HttpStatus.INTERNAL_SERVER_ERROR); + } + } + + // 创建文件并写入内容 + if (!file1.exists()) { + try { + boolean fileCreated = file1.createNewFile(); + if (fileCreated) { + System.out.println("File created successfully: " + tempPythonFile); + } else { + System.out.println("File already exists: " + tempPythonFile); + } + } catch (IOException e) { + e.printStackTrace(); + return new ResultEntity<>(HttpStatus.INTERNAL_SERVER_ERROR); + } + } + + try (PrintWriter out = new PrintWriter(file1)) { + out.println(updatedCode); + } + + + // 确认 Docker 命令 + String[] command = {"docker", "exec", "pyexe", "python", tempPythonFile}; + + Process process = Runtime.getRuntime().exec(command); + + // 获取进程的输入流 + BufferedReader inputStream = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader errorStream = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + // 读取 Python 代码的输出 + StringBuilder output = new StringBuilder(); + String line; + while ((line = inputStream.readLine()) != null) { + output.append(line).append("\n"); + } + + // 读取 Python 代码的错误信息 + StringBuilder errors = new StringBuilder(); + while ((line = errorStream.readLine()) != null) { + errors.append(line).append("\n"); + } + + // 等待进程执行完成 + int exitCode = process.waitFor(); + if (exitCode == 0) { + List comments = new ArrayList<>(); + + + //取到输出值转为字符串 + String strings = output.toString().trim(); + String[] split = strings.split(","); + String aspect = split[0].split(": ")[1]; + String opinionsStr = split[1].split(": ")[1]; + String sentiment = split[2].split(": ")[1]; + + + //将字符串转为集合 + opinionsStr = opinionsStr.replace("'", "").replace("[", "").replace("]", "").trim(); + List opinions = Arrays.asList(opinionsStr.split(", ")); + + + comments.add(new CommentDTO(aspect, opinions, sentiment)); + + + + // System.out.println("Python code output:\n" + output.toString()); + return new ResultEntity(HttpStatus.OK, comments); + + + } else { + // System.err.println("Error executing Python code:\n" + errors.toString()); + return new ResultEntity(HttpStatus.INTERNAL_SERVER_ERROR, errors.toString()); + } + } catch (IOException | InterruptedException e) { + e.printStackTrace(); + return new ResultEntity<>(HttpStatus.INTERNAL_SERVER_ERROR); + } + } + } diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/CommentDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/CommentDTO.java new file mode 100644 index 0000000..6c464c5 --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/entity/dto/CommentDTO.java @@ -0,0 +1,34 @@ +package com.sztzjy.marketing.entity.dto; + +import lombok.Data; + +import java.util.List; + +/** + * @author tz + * @date 2024/8/27 13:30 + */ +@Data +public class CommentDTO { + private String aspect; + + private List opinions; + + private String sentiment; + + + public CommentDTO(String aspect, List opinions, String sentiment) { + this.aspect = aspect; + this.opinions = opinions; + this.sentiment = sentiment; + } + + @Override + public String toString() { + return "CommentDTO{" + + "aspect='" + aspect + '\'' + + ", opinions=" + opinions + + ", sentiment='" + sentiment + '\'' + + '}'; + } +} diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/LogisticPredictDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/LogisticPredictDTO.java index 68e920c..0816d22 100644 --- a/src/main/java/com/sztzjy/marketing/entity/dto/LogisticPredictDTO.java +++ b/src/main/java/com/sztzjy/marketing/entity/dto/LogisticPredictDTO.java @@ -11,7 +11,7 @@ import org.ujmp.core.util.R; @EqualsAndHashCode(callSuper = true) @Data public class LogisticPredictDTO extends RAnalysisDTO { - private double[] raX; + private double raX; private double lX; diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java index b491263..272b5b1 100644 --- a/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java +++ b/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java @@ -11,10 +11,6 @@ import lombok.Data; public class RAnalysisDTO { @ApiModelProperty("回归系数值") private double intercept; - @ApiModelProperty("线性回归常数项值") - private double[] slopes; @ApiModelProperty("逻辑回归常数项值") private double slope; - @ApiModelProperty("线性回归/逻辑回归") - private String type; } diff --git a/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java b/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java index 8ad6ada..6c4f948 100644 --- a/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java +++ b/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java @@ -314,7 +314,6 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } else if(tableName.equals(Constant.YONGHUXIAOSHOUNLB)){ // list=indicatorsMapper.getYHDLHYB(); - list.add("id"); list.add("sales_volume"); list.add("sales_forehead"); @@ -405,7 +404,10 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM Collections.rotate(attributes,-1); } } - return new ResultEntity(HttpStatus.OK,attributes); + // 过滤掉包含逗号的字符串值,并将负数转换为正数(对于数字类型) + List> filteredAttributes = filterAndConvertAttributes(attributes); + + return new ResultEntity(HttpStatus.OK,filteredAttributes); } List fieldList = analyzeDataDTO.getFieldList(); @@ -586,4 +588,37 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } return filteredList; } + + + // 过滤掉包含逗号的字符串值,并将负数转换为正数(对于数字类型) + public static List> filterAndConvertAttributes(List> attributes) { + List> filteredAndConvertedList = new ArrayList<>(); + + for (Map map : attributes) { + Map convertedMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Object value = entry.getValue(); + + // 检查值是否为字符串且包含逗号 + if (value instanceof String) { + // 去除字符串中的逗号 + String strValue = (String) value; + value = strValue.replace(",", "").replace("-",""); + } + + // 检查值是否为数字类型且为负数 + if (value instanceof Number && ((Number) value).doubleValue() < 0) { + // 转换为正数 + value = Math.abs(((Number) value).doubleValue()); + // 注意:这里为了保持类型一致,我们将其转换回原始类型(如果需要) + // 这里简化为double类型,实际中可能需要更复杂的处理来保持原始类型 + } + + convertedMap.put(entry.getKey(), value); + } + filteredAndConvertedList.add(convertedMap); + } + + return filteredAndConvertedList; + } }