Compare commits

..

2 Commits

Author SHA1 Message Date
@t2652009480 fc623db0ca 数字营销实训算法第五轮修改 7 months ago
@t2652009480 48588b59aa 数字营销实训算法第五轮修改 7 months ago

@ -10,6 +10,7 @@ import com.sztzjy.marketing.mapper.StuSpendingLevelMapper;
import com.sztzjy.marketing.mapper.StuUserCommentMapper;
import com.sztzjy.marketing.mapper.StuUserSalesAbilityMapper;
import com.sztzjy.marketing.service.StuDigitalMarketingModelService;
import com.sztzjy.marketing.util.BigDecimalUtils;
import com.sztzjy.marketing.util.DataConverter;
import com.sztzjy.marketing.util.ResultEntity;
import com.sztzjy.marketing.util.algorithm.Apriori;
@ -229,18 +230,17 @@ public class StuDigitalMarketingModelController {
@PostMapping("/logistic")
@AnonymousAccess
public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) {
    // Fits a line to the posted samples and returns its intercept and slope,
    // each rounded to one decimal place, wrapped in an RAnalysisDTO.
    double[] x = logisticDTO.getX();
    double[] y = logisticDTO.getY();

    // fit() returns the coefficients as {intercept, slope}.
    LogisticRegression lr = new LogisticRegression(x, y);
    double[] coefficients = lr.fit();

    // Formatting with "#.0" and parsing the text back rounds each value to one decimal.
    // NOTE(review): DecimalFormat uses the default locale here; the parse step assumes
    // '.' is the decimal separator — confirm for the deployment locale.
    DecimalFormat df = new DecimalFormat("#.0");
    RAnalysisDTO rAnalysisDTO = new RAnalysisDTO();
    rAnalysisDTO.setIntercept(Double.parseDouble(df.format(coefficients[0])));
    rAnalysisDTO.setSlope(Double.parseDouble(df.format(coefficients[1])));
    return new ResultEntity(HttpStatus.OK, "成功", rAnalysisDTO);
}
@ -251,11 +251,9 @@ public class StuDigitalMarketingModelController {
@AnonymousAccess
public ResultEntity prediction(@RequestBody LogisticPredictDTO logisticPredictDTO) {
DecimalFormat df = new DecimalFormat("#.0");
// 创建一个逻辑回归模型实例,0.1学习率
LogisticRegression model = new LogisticRegression(0.1);
double predict = model.predict(logisticPredictDTO.getLX(), logisticPredictDTO.getSlope(), logisticPredictDTO.getIntercept());//预测
DecimalFormat df = new DecimalFormat("#.0");
BigDecimalUtils bigDecimalUtils=new BigDecimalUtils();
double predict = bigDecimalUtils.add(bigDecimalUtils.mul(logisticPredictDTO.getSlope(),logisticPredictDTO.getX()),logisticPredictDTO.getIntercept());
return new ResultEntity(HttpStatus.OK,"成功",df.format(predict));

@ -11,7 +11,7 @@ import lombok.Data;
public class LogisticDTO {
@ApiModelProperty("自变量x")
private double[][] x;
private double[] x;
@ApiModelProperty("因变量y")
private double[] y;

@ -11,8 +11,7 @@ import org.ujmp.core.util.R;
/**
 * Request body of the prediction endpoint. It extends {@link RAnalysisDTO}, whose
 * slope and intercept accessors the controller reads alongside the input value
 * declared here.
 *
 * NOTE(review): this listing comes from a change set that replaces the older input
 * fields with {@code x}; confirm which fields are still live before relying on them.
 */
@EqualsAndHashCode(callSuper = true)
@Data
public class LogisticPredictDTO extends RAnalysisDTO {
// NOTE(review): no reader of raX is visible in this change set — presumably obsolete; verify.
private double raX;
// Input value read through getLX() by the earlier prediction code.
private double lX;
// Input value read through getX() by the current prediction code (slope * x + intercept).
private double x;
}

@ -21,6 +21,8 @@ import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.text.DecimalFormat;
import java.time.Instant;
@ -155,13 +157,16 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
List<Term> terms = HanLP.segment(string);
// 创建词汇表
Map<String, Integer> wordCount = new HashMap<>();
// 统计词频
// 读取停用词
Set<String> stopwords = this.loadStopwords("/usr/local/tianzeProject/digitalMarketing/jarAndDockerFile/停用词.txt");
// Set<String> stopwords = this.loadStopwords("D:\\project\\digital_marketing\\src\\main\\resources\\停用词.txt");
// 创建计数器
// 统计词频
int count = 0;
for (Term term : terms) {
String word = term.word.trim(); // 去除空格
if(!word.isEmpty()){
if (!word.isEmpty() && !stopwords.contains(word)) { // 过滤停用词
count++;
if (wordCount.containsKey(word)) {
wordCount.put(word, wordCount.get(word) + 1);
@ -173,20 +178,22 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
break;
}
}
// 将词汇表转换为列表,便于排序
List<Map.Entry<String, Integer>> wordList = new ArrayList<>(wordCount.entrySet());
// 按词频排序
wordList.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue()));
List<WordFrequencyDTO> list=new ArrayList<>();
List<WordFrequencyDTO> list = new ArrayList<>();
for (Map.Entry<String, Integer> entry : wordList) {
WordFrequencyDTO frequencyDTO=new WordFrequencyDTO();
WordFrequencyDTO frequencyDTO = new WordFrequencyDTO();
frequencyDTO.setKeyword(entry.getKey());
frequencyDTO.setFrequency(entry.getValue());
list.add(frequencyDTO);
}
return new ResultEntity<>(HttpStatus.OK,"成功",list);
return new ResultEntity<>(HttpStatus.OK, "成功", list);
}
// if(modelType.equals(Constant.WORD_CLOUD)){ //词云分析
@ -621,4 +628,16 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
return filteredAndConvertedList;
}
/**
 * Loads a stop-word list from a text file containing one word per line.
 *
 * <p>The file is decoded explicitly as UTF-8 instead of the platform default
 * charset, so the result does not depend on the JVM/OS locale settings.
 * NOTE(review): assumes the stop-word file is stored as UTF-8 — confirm.
 * A leading byte-order mark is dropped, each line is trimmed, and blank lines
 * are ignored.
 *
 * @param filePath path of the stop-word file
 * @return a mutable set of the trimmed, non-empty stop words
 * @throws IOException if the file cannot be opened or read
 */
Set<String> loadStopwords(String filePath) throws IOException {
    Set<String> stopwords = new HashSet<>();
    try (BufferedReader br = new BufferedReader(new java.io.InputStreamReader(
            new java.io.FileInputStream(filePath), java.nio.charset.StandardCharsets.UTF_8))) {
        String line;
        boolean firstLine = true;
        while ((line = br.readLine()) != null) {
            if (firstLine) {
                firstLine = false;
                // A UTF-8 BOM would otherwise become part of the first stop word.
                if (!line.isEmpty() && line.charAt(0) == '\uFEFF') {
                    line = line.substring(1);
                }
            }
            String word = line.trim();
            if (!word.isEmpty()) {
                stopwords.add(word);
            }
        }
    }
    return stopwords;
}
}

@ -1,121 +1,44 @@
package com.sztzjy.marketing.util.algorithm;
import com.sztzjy.marketing.util.BigDecimalUtils;
import java.text.DecimalFormat;
/**
 * Fits a straight line {@code y = intercept + slope * x} to paired samples using
 * ordinary least squares.
 *
 * <p>NOTE(review): despite the class name, the fit is a plain linear one; no
 * sigmoid/logit transform is applied to the data.
 */
public class LogisticRegression {

    /** Independent-variable samples. */
    private final double[] x;

    /** Dependent-variable samples, paired with {@code x} by index. */
    private final double[] y;

    /**
     * Stores the sample arrays; they are validated when {@link #fit()} is called.
     *
     * @param x independent-variable samples
     * @param y dependent-variable samples, same length as {@code x}
     */
    public LogisticRegression(double[] x, double[] y) {
        this.x = x;
        this.y = y;
    }

    /**
     * Computes the least-squares line through the samples.
     *
     * @return a two-element array {@code {intercept, slope}}
     * @throws IllegalArgumentException if either array is {@code null}, the arrays
     *         are empty or differ in length, all x values are identical (the slope
     *         is then undefined), or the data does not yield finite coefficients
     */
    public double[] fit() {
        if (x == null || y == null || x.length != y.length || x.length == 0) {
            throw new IllegalArgumentException("Invalid input data");
        }

        int n = x.length;
        double sumX = 0;
        double sumY = 0;
        double minX = x[0];
        double maxX = x[0];
        for (int i = 0; i < n; i++) {
            sumX += x[i];
            sumY += y[i];
            minX = Math.min(minX, x[i]);
            maxX = Math.max(maxX, x[i]);
        }

        // With zero spread in x the denominator of the slope is zero. Comparing min and
        // max detects this exactly, where a rounded sum-of-squares difference might not.
        // The negated comparison also rejects NaN inputs.
        if (!(minX < maxX)) {
            throw new IllegalArgumentException(
                    "Cannot fit a line: x values must contain at least two distinct numbers");
        }

        // Centered sums avoid the cancellation of n*sum(x^2) - sum(x)^2 for large x.
        double meanX = sumX / n;
        double meanY = sumY / n;
        double sxx = 0;
        double sxy = 0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            sxx += dx * dx;
            sxy += dx * (y[i] - meanY);
        }

        double slope = sxy / sxx;
        double intercept = meanY - slope * meanX;
        if (!Double.isFinite(slope) || !Double.isFinite(intercept)) {
            throw new IllegalArgumentException("Invalid input data");
        }

        return new double[] { intercept, slope };
    }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save