数字营销实训算法第五轮修改

master
@t2652009480 7 months ago
parent fcd2c83048
commit c0d3006986

@ -204,25 +204,25 @@ public class StuDigitalMarketingModelController {
}
@ApiOperation("回归分析--线性回归")
@PostMapping("/regressionAnalysis")
@AnonymousAccess
public ResultEntity regressionAnalysis(@RequestBody RegressionAnalysisDTO analysisDTO) {
double[][] x = analysisDTO.getX();
double[] y = analysisDTO.getY();
// 创建线性回归模型
LinearRegression regression=new LinearRegression();
regression.fit(x,y);
RAnalysisDTO rAnalysisDTO=new RAnalysisDTO();
rAnalysisDTO.setIntercept(regression.getIntercept());
rAnalysisDTO.setSlopes(regression.getSlopes());
rAnalysisDTO.setType("线性回归");
return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO);
}
// @ApiOperation("回归分析--线性回归")
// @PostMapping("/regressionAnalysis")
// @AnonymousAccess
// public ResultEntity regressionAnalysis(@RequestBody RegressionAnalysisDTO analysisDTO) {
// double[][] x = analysisDTO.getX();
// double[] y = analysisDTO.getY();
//
// // 创建线性回归模型
// LinearRegression regression=new LinearRegression();
// regression.fit(x,y);
//
//
// RAnalysisDTO rAnalysisDTO=new RAnalysisDTO();
// rAnalysisDTO.setIntercept(regression.getIntercept());
// rAnalysisDTO.setSlopes(regression.getSlopes());
// rAnalysisDTO.setType("线性回归");
//
// return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO);
// }
@ApiOperation("回归分析--逻辑回归")
@ -241,7 +241,6 @@ public class StuDigitalMarketingModelController {
rAnalysisDTO.setIntercept(Double.parseDouble(df.format(model.getIntercept())));
rAnalysisDTO.setSlope(Double.parseDouble(df.format(model.getCoefficient())));
rAnalysisDTO.setType("逻辑回归");
return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO);
}
@ -253,18 +252,12 @@ public class StuDigitalMarketingModelController {
public ResultEntity prediction(@RequestBody LogisticPredictDTO logisticPredictDTO) {
DecimalFormat df = new DecimalFormat("#.0");
if(logisticPredictDTO.getType().equals(Constant.LINEAR_REGRESSION)){
// 创建线性回归模型
LinearRegression regression=new LinearRegression();
double predict = regression.predict(logisticPredictDTO.getRaX(), logisticPredictDTO.getIntercept(), logisticPredictDTO.getSlopes());
return new ResultEntity(HttpStatus.OK,"成功",df.format(predict));
}else {
// 创建一个逻辑回归模型实例,0.1学习率
LogisticRegression model = new LogisticRegression(0.1);
double predict = model.predict(logisticPredictDTO.getLX(), logisticPredictDTO.getSlope(), logisticPredictDTO.getIntercept());//预测
return new ResultEntity(HttpStatus.OK,"成功",df.format(predict));
}
}

@ -184,46 +184,46 @@ public class StuOnAnalysisController {
}
@ApiOperation("回归分析表格下载")
@PostMapping("/rAMExport")
@AnonymousAccess
public void rAMExport(HttpServletResponse response, @RequestBody RAnalysisExportDTO rAnalysisExportDTO) {
//导出的表名
String title = IdUtil.simpleUUID();
List<String> list=new ArrayList<>();
List<String> listColumn=new ArrayList<>();
list.add("回归系数值");
if(rAnalysisExportDTO.getType().equals(Constant.LINEAR_REGRESSION)){
list.add("线性回归常数项值");
//具体需要写入excel需要哪些字段这些字段取自UserReward类也就是上面的实际数据结果集的泛型
listColumn = Arrays.asList("intercept", "slopes", "type","prediction");
}else {
list.add("逻辑回归常数项值");
//具体需要写入excel需要哪些字段这些字段取自UserReward类也就是上面的实际数据结果集的泛型
listColumn = Arrays.asList("intercept", "slope", "type","prediction");
}
list.add("回归类型");
list.add("预测值");
//表中第一行表头字段
String[] headers=list.toArray(new String[list.size()]); // 使用 list.toArray() 转换为 String数组
//实际数据结果集
List<RAnalysisExportDTO> rAnalysisExportDTOS=new ArrayList<>();
rAnalysisExportDTOS.add(rAnalysisExportDTO);
try {
FilePortUtil.exportExcel(response, title, headers, rAnalysisExportDTOS, listColumn);
} catch (Exception e) {
e.printStackTrace();
}
}
// @ApiOperation("回归分析表格下载")
// @PostMapping("/rAMExport")
// @AnonymousAccess
// public void rAMExport(HttpServletResponse response, @RequestBody RAnalysisExportDTO rAnalysisExportDTO) {
// //导出的表名
// String title = IdUtil.simpleUUID();
//
// List<String> list=new ArrayList<>();
// List<String> listColumn=new ArrayList<>();
// list.add("回归系数值");
//
//
// if(rAnalysisExportDTO.getType().equals(Constant.LINEAR_REGRESSION)){
//
// list.add("线性回归常数项值");
// //具体需要写入excel需要哪些字段这些字段取自UserReward类也就是上面的实际数据结果集的泛型
// listColumn = Arrays.asList("intercept", "slopes", "type","prediction");
//
//
// }else {
// list.add("逻辑回归常数项值");
// //具体需要写入excel需要哪些字段这些字段取自UserReward类也就是上面的实际数据结果集的泛型
// listColumn = Arrays.asList("intercept", "slope", "type","prediction");
// }
// list.add("回归类型");
// list.add("预测值");
//
// //表中第一行表头字段
// String[] headers=list.toArray(new String[list.size()]); // 使用 list.toArray() 转换为 String数组
//
//
// //实际数据结果集
// List<RAnalysisExportDTO> rAnalysisExportDTOS=new ArrayList<>();
// rAnalysisExportDTOS.add(rAnalysisExportDTO);
//
//
// try {
// FilePortUtil.exportExcel(response, title, headers, rAnalysisExportDTOS, listColumn);
// } catch (Exception e) {
// e.printStackTrace();
// }
// }
}

@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.sztzjy.marketing.annotation.AnonymousAccess;
import com.sztzjy.marketing.entity.dto.CommentDTO;
import com.sztzjy.marketing.entity.dto.SentimentAnalyDTO;
import com.sztzjy.marketing.qianfan.util.Json;
import com.sztzjy.marketing.util.ResultEntity;
@ -17,8 +18,7 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.*;
import java.util.List;
import java.util.Map;
import java.util.*;
@RestController
@ -141,4 +141,234 @@ public class StuPythonController {
}
}
@PostMapping("/comment")
@ApiOperation("评论点抽取")
@AnonymousAccess
public ResultEntity comment(@RequestBody JSONObject text) {
// System.out.println(text);
String testText = text.getString("text");
String code = "# 导入所需依赖\n" +
"import pandas as pd\n" +
"import paddle\n" +
"from paddlenlp.transformers import SkepTokenizer, SkepModel\n" +
"from utils.utils import decoding, concate_aspect_and_opinion, format_print\n" +
"from utils import data_ext, data_cls\n" +
"from utils.model_define import SkepForTokenClassification, SkepForSequenceClassification\n" +
"\n" +
"\n" +
"# 单条文本情感分析预测函数\n" +
"def predict(input_text, ext_model, cls_model, tokenizer, ext_id2label, cls_id2label, max_length=512):\n" +
" ext_model.eval()\n" +
" cls_model.eval()\n" +
"\n" +
" # processing input text\n" +
" encoded_inputs = tokenizer(list(input_text), is_split_into_words=True, max_length=max_length,)\n" +
" input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" +
" token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" +
"\n" +
" # extract aspect and opinion words\n" +
" logits = ext_model(input_ids, token_type_ids=token_type_ids)\n" +
" predictions = logits.argmax(axis=2).numpy()[0]\n" +
" tag_seq = [ext_id2label[idx] for idx in predictions][1:-1]\n" +
" aps = decoding(input_text, tag_seq)\n" +
"\n" +
" # predict sentiment for aspect with cls_model\n" +
" results = []\n" +
" for ap in aps:\n" +
" aspect = ap[0]\n" +
" opinion_words = list(set(ap[1:]))\n" +
" aspect_text = concate_aspect_and_opinion(input_text, aspect, opinion_words)\n" +
" \n" +
" encoded_inputs = tokenizer(aspect_text, text_pair=input_text, max_length=max_length, return_length=True)\n" +
" input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" +
" token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" +
"\n" +
" logits = cls_model(input_ids, token_type_ids=token_type_ids)\n" +
" prediction = logits.argmax(axis=1).numpy()[0]\n" +
"\n" +
" result = {\"aspect\": aspect, \"opinions\": opinion_words, \"sentiment\": cls_id2label[prediction]}\n" +
" results.append(result)\n" +
"\n" +
" # print results\n" +
" format_print(results)\n" +
"\n" +
" # 返回预测结果\n" +
" return results\n" +
"\n" +
"# 批量情感分析预测函数\n" +
"def batchPredict(data, ext_model, cls_model, tokenizer, ext_id2label, cls_id2label, max_length=512):\n" +
"\n" +
" ext_model.eval()\n" +
" cls_model.eval()\n" +
"\n" +
" analysisResults = []\n" +
"\n" +
" # 针对批量文本逐条处理\n" +
" for input_text in data:\n" +
" # processing input text\n" +
" encoded_inputs = tokenizer(list(input_text), is_split_into_words=True, max_length=max_length,)\n" +
" input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" +
" token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" +
"\n" +
" # extract aspect and opinion words\n" +
" logits = ext_model(input_ids, token_type_ids=token_type_ids)\n" +
" predictions = logits.argmax(axis=2).numpy()[0]\n" +
" tag_seq = [ext_id2label[idx] for idx in predictions][1:-1]\n" +
" aps = decoding(input_text, tag_seq)\n" +
"\n" +
" # predict sentiment for aspect with cls_model\n" +
" results = []\n" +
" for ap in aps:\n" +
" aspect = ap[0]\n" +
" opinion_words = list(set(ap[1:]))\n" +
" aspect_text = concate_aspect_and_opinion(input_text, aspect, opinion_words)\n" +
" \n" +
" encoded_inputs = tokenizer(aspect_text, text_pair=input_text, max_length=max_length, return_length=True)\n" +
" input_ids = paddle.to_tensor([encoded_inputs[\"input_ids\"]])\n" +
" token_type_ids = paddle.to_tensor([encoded_inputs[\"token_type_ids\"]])\n" +
"\n" +
" logits = cls_model(input_ids, token_type_ids=token_type_ids)\n" +
" prediction = logits.argmax(axis=1).numpy()[0]\n" +
"\n" +
" result = {\"属性\": aspect, \"观点\": opinion_words, \"情感倾向\": cls_id2label[prediction]}\n" +
" results.append(result)\n" +
" singleResult = {\"text\": input_text, \"result\": str(results)}\n" +
" analysisResults.append(singleResult)\n" +
"\n" +
" # 返回预测结果 list形式\n" +
" return analysisResults\n" +
"\n" +
"\n" +
"label_ext_path = \"./label_ext.dict\"\n" +
"label_cls_path = \"./label_cls.dict\"\n" +
"# 加载PaddleNLP开源的基于全量数据训练好的评论观点抽取模型和属性级情感分类模型\n" +
"ext_model_path = \"./model/best_ext.pdparams\"\n" +
"cls_model_path = \"./model/best_cls.pdparams\"\n" +
"\n" +
"# load dict\n" +
"model_name = \"skep_ernie_1.0_large_ch\"\n" +
"ext_label2id, ext_id2label = data_ext.load_dict(label_ext_path)\n" +
"cls_label2id, cls_id2label = data_cls.load_dict(label_cls_path)\n" +
"tokenizer = SkepTokenizer.from_pretrained(model_name)\n" +
"print(\"label dict loaded.\")\n" +
"\n" +
"# load ext model 观点抽取模型\n" +
"ext_state_dict = paddle.load(ext_model_path)\n" +
"ext_skep = SkepModel.from_pretrained(model_name)\n" +
"ext_model = SkepForTokenClassification(ext_skep, num_classes=len(ext_label2id))\n" +
"ext_model.load_dict(ext_state_dict)\n" +
"print(\"extraction model loaded.\")\n" +
"\n" +
"# load cls model 属性级情感分析模型\n" +
"cls_state_dict = paddle.load(cls_model_path)\n" +
"cls_skep = SkepModel.from_pretrained(model_name)\n" +
"cls_model = SkepForSequenceClassification(cls_skep, num_classes=len(cls_label2id))\n" +
"cls_model.load_dict(cls_state_dict)\n" +
"print(\"classification model loaded.\")\n" +
"\n" +
"# 单条文本情感分析\n" +
"max_length = 512\n" +
"input_text = '经济实惠、动力不错、油耗低'\n" +
"predict(input_text, ext_model, cls_model, tokenizer, ext_id2label, cls_id2label, max_length=max_length)";
// 替换代码中的 test_text 为实际的 testText
String updatedCode = code.replace("'经济实惠、动力不错、油耗低'", "'" + testText + "'");
//System.out.println(updatedCode);
try {
String s = IdUtil.simpleUUID();
String tempPythonFile = "/usr/local/tianzeProject/digitalMarketing/comment/backend/" + s + ".py";
File file1 = new File(tempPythonFile);
// 确保父目录存在
File parentDir = file1.getParentFile();
if (!parentDir.exists()) {
System.out.println("Parent directory does not exist. Creating it.");
if (!parentDir.mkdirs()) {
System.out.println("Failed to create directories.");
return new ResultEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
}
// 创建文件并写入内容
if (!file1.exists()) {
try {
boolean fileCreated = file1.createNewFile();
if (fileCreated) {
System.out.println("File created successfully: " + tempPythonFile);
} else {
System.out.println("File already exists: " + tempPythonFile);
}
} catch (IOException e) {
e.printStackTrace();
return new ResultEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
}
try (PrintWriter out = new PrintWriter(file1)) {
out.println(updatedCode);
}
// 确认 Docker 命令
String[] command = {"docker", "exec", "pyexe", "python", tempPythonFile};
Process process = Runtime.getRuntime().exec(command);
// 获取进程的输入流
BufferedReader inputStream = new BufferedReader(new InputStreamReader(process.getInputStream()));
BufferedReader errorStream = new BufferedReader(new InputStreamReader(process.getErrorStream()));
// 读取 Python 代码的输出
StringBuilder output = new StringBuilder();
String line;
while ((line = inputStream.readLine()) != null) {
output.append(line).append("\n");
}
// 读取 Python 代码的错误信息
StringBuilder errors = new StringBuilder();
while ((line = errorStream.readLine()) != null) {
errors.append(line).append("\n");
}
// 等待进程执行完成
int exitCode = process.waitFor();
if (exitCode == 0) {
List<CommentDTO> comments = new ArrayList<>();
//取到输出值转为字符串
String strings = output.toString().trim();
String[] split = strings.split(",");
String aspect = split[0].split(": ")[1];
String opinionsStr = split[1].split(": ")[1];
String sentiment = split[2].split(": ")[1];
//将字符串转为集合
opinionsStr = opinionsStr.replace("'", "").replace("[", "").replace("]", "").trim();
List<String> opinions = Arrays.asList(opinionsStr.split(", "));
comments.add(new CommentDTO(aspect, opinions, sentiment));
// System.out.println("Python code output:\n" + output.toString());
return new ResultEntity(HttpStatus.OK, comments);
} else {
// System.err.println("Error executing Python code:\n" + errors.toString());
return new ResultEntity(HttpStatus.INTERNAL_SERVER_ERROR, errors.toString());
}
} catch (IOException | InterruptedException e) {
e.printStackTrace();
return new ResultEntity<>(HttpStatus.INTERNAL_SERVER_ERROR);
}
}
}

@ -0,0 +1,34 @@
package com.sztzjy.marketing.entity.dto;
import lombok.Data;
import java.util.List;
/**
* @author tz
* @date 2024/8/27 13:30
*/
@Data
public class CommentDTO {
private String aspect;
private List<String> opinions;
private String sentiment;
public CommentDTO(String aspect, List<String> opinions, String sentiment) {
this.aspect = aspect;
this.opinions = opinions;
this.sentiment = sentiment;
}
@Override
public String toString() {
return "CommentDTO{" +
"aspect='" + aspect + '\'' +
", opinions=" + opinions +
", sentiment='" + sentiment + '\'' +
'}';
}
}

@ -11,7 +11,7 @@ import org.ujmp.core.util.R;
@EqualsAndHashCode(callSuper = true)
@Data
public class LogisticPredictDTO extends RAnalysisDTO {
private double[] raX;
private double raX;
private double lX;

@ -11,10 +11,6 @@ import lombok.Data;
public class RAnalysisDTO {
@ApiModelProperty("回归系数值")
private double intercept;
@ApiModelProperty("线性回归常数项值")
private double[] slopes;
@ApiModelProperty("逻辑回归常数项值")
private double slope;
@ApiModelProperty("线性回归/逻辑回归")
private String type;
}

@ -314,7 +314,6 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
else if(tableName.equals(Constant.YONGHUXIAOSHOUNLB)){
// list=indicatorsMapper.getYHDLHYB();
list.add("id");
list.add("sales_volume");
list.add("sales_forehead");
@ -405,7 +404,10 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
Collections.rotate(attributes,-1);
}
}
return new ResultEntity(HttpStatus.OK,attributes);
// 过滤掉包含逗号的字符串值,并将负数转换为正数(对于数字类型)
List<Map<String, Object>> filteredAttributes = filterAndConvertAttributes(attributes);
return new ResultEntity(HttpStatus.OK,filteredAttributes);
}
List<String> fieldList = analyzeDataDTO.getFieldList();
@ -586,4 +588,37 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
return filteredList;
}
// 过滤掉包含逗号的字符串值,并将负数转换为正数(对于数字类型)
public static List<Map<String, Object>> filterAndConvertAttributes(List<Map<String, Object>> attributes) {
List<Map<String, Object>> filteredAndConvertedList = new ArrayList<>();
for (Map<String, Object> map : attributes) {
Map<String, Object> convertedMap = new HashMap<>();
for (Map.Entry<String, Object> entry : map.entrySet()) {
Object value = entry.getValue();
// 检查值是否为字符串且包含逗号
if (value instanceof String) {
// 去除字符串中的逗号
String strValue = (String) value;
value = strValue.replace(",", "").replace("-","");
}
// 检查值是否为数字类型且为负数
if (value instanceof Number && ((Number) value).doubleValue() < 0) {
// 转换为正数
value = Math.abs(((Number) value).doubleValue());
// 注意:这里为了保持类型一致,我们将其转换回原始类型(如果需要)
// 这里简化为double类型实际中可能需要更复杂的处理来保持原始类型
}
convertedMap.put(entry.getKey(), value);
}
filteredAndConvertedList.add(convertedMap);
}
return filteredAndConvertedList;
}
}

Loading…
Cancel
Save