Compare commits
2 Commits
5f5521b065...fc623db0ca
| Author | SHA1 | Date |
|---|---|---|
| | fc623db0ca | 7 months ago |
| | 48588b59aa | 7 months ago |
@@ -1,121 +1,44 @@
 package com.sztzjy.marketing.util.algorithm;

-import com.sztzjy.marketing.util.BigDecimalUtils;
-
-import java.text.DecimalFormat;
-
 public class LogisticRegression {
-    private double theta0;       // intercept term
-    private double theta1;       // slope
-    private double learningRate; // learning rate
-
-    private int currentEpoch;    // declare the currentEpoch variable
-
-    public LogisticRegression(double learningRate) {
-        this.learningRate = learningRate;
-    }
-
-    public LogisticRegression(double learningRate, double initialTheta0, double initialTheta1) {
-        this.learningRate = learningRate;
-        this.theta0 = initialTheta0;
-        this.theta1 = initialTheta1;
-        this.currentEpoch = 0; // initialize currentEpoch in the constructor
-    }
-
-    // Method to get the intercept term
-    public double getIntercept() {
-        return theta0;
-    }
-
-    // Method to get the coefficient of the first feature
-    public double getCoefficient() {
-        return theta1;
-    }
-
-    // Sigmoid function
-    private double sigmoid(double z) {
-        return 1 / (1 + Math.exp(-z));
-    }
-
-    // Prediction function
-    public double predict(double x, double slopes, double intercept) {
-        BigDecimalUtils bigDecimalUtils = new BigDecimalUtils();
-        Double mul = bigDecimalUtils.mul(intercept, x);
-        return bigDecimalUtils.add(slopes, mul);
-    }
-
-    private double dotProduct(double[] a, double[] b) {
-        double result = 0.0;
-        for (int i = 0; i < a.length; i++) {
-            result += a[i] * b[i];
-        }
-        return result;
-    }
-
-//
-//    // Gradient-descent parameter update
-//    public void updateParameters(double x, double y) {
-//        double h = predict(x);
-//        double error = y - h;
-//        theta1 += learningRate * error * h * (1 - h) * x;
-//        theta0 += learningRate * error * h * (1 - h);
-//    }
-
-    private double exponentialDecayLearningRate(double currentEpoch, int totalEpochs) {
-        return learningRate * Math.exp(-(currentEpoch / totalEpochs));
-    }
-
-    // Train the model
-    public void train(double[][] data, double[] labels, int epochs) {
-        double learningRateCurrentEpoch = exponentialDecayLearningRate(currentEpoch, epochs);
-        for (int epoch = 0; epoch < epochs; epoch++) {
-            for (int i = 0; i < data.length; i++) {
-                double x = data[i][0];
-                double y = labels[i];
-                double h = sigmoid(theta0 + theta1 * x);
-                double error = y - h;
-                theta1 += learningRateCurrentEpoch * error * h * (1 - h) * x;
-                theta0 += learningRateCurrentEpoch * error * h * (1 - h);
-            }
-            // Print statements can be added here to monitor training
-            // System.out.println("Epoch: " + (epoch + 1) + ", Loss: " + calculateLoss(data, labels));
-        }
-    }
-
-    public double calculateLoss(double[][] data, double[] labels) {
-        double loss = 0;
-        for (int i = 0; i < data.length; i++) {
-            double x = data[i][0];
-            double y = labels[i];
-            double h = sigmoid(theta0 + theta1 * x);
-            loss += -y * Math.log(h) - (1 - y) * Math.log(1 - h);
-        }
-        return loss / data.length;
-    }
-
-//    public static void main(String[] args) {
-//        LogisticRegression model = new LogisticRegression(0.1); // create a logistic regression model instance
-//
-////        double[][] data = {{1}, {2}, {3}, {4}, {5}};
-////        double[] labels = {0, 0, 1, 1, 1}; // assume this is a binary classification problem
-//
-//        double[][] data = {{1, 19}, {1, 21}, {2, 20}, {2, 23}, {2, 31}};
-//        double[] labels = {15, 15, 16, 16, 17};
-//
-//        model.train(data, labels, 1000); // train the model
-//
-//        DecimalFormat df = new DecimalFormat("#.0");
-//
-//        // Get and print the intercept and the coefficient
-//        double intercept = model.getIntercept();
-//        double coefficient = model.getCoefficient();
-//        System.out.println("Constant term: " + df.format(intercept));
-//        System.out.println("Coefficient: " + df.format(coefficient));
-//
-//        double prediction = model.predict(2);
-//        System.out.println("Prediction for x=3.5: " + df.format(prediction));
-//    }
+    private double[] x;
+    private double[] y;
+
+    public LogisticRegression(double[] x, double[] y) {
+        this.x = x;
+        this.y = y;
+    }
+
+    // Compute the linear-regression parameters
+    public double[] fit() {
+        if (x.length != y.length || x.length == 0) {
+            throw new IllegalArgumentException("Invalid input data");
+        }
+
+        int n = x.length;
+        double sumX = 0;
+        double sumY = 0;
+        double sumXY = 0;
+        double sumXX = 0;
+
+        // Accumulate the sums
+        for (int i = 0; i < n; i++) {
+            sumX += x[i];
+            sumY += y[i];
+            sumXY += x[i] * y[i];
+            sumXX += x[i] * x[i];
+        }
+
+        // Compute the slope and the intercept
+        double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX);
+        double intercept = (sumY - slope * sumX) / n;
+
+        return new double[] { intercept, slope };
+    }
 }
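Despite the class name, the version at the head of this compare (fc623db0ca) fits a simple linear regression rather than a logistic one: `fit()` evaluates the ordinary-least-squares closed form

$$\text{slope} = \frac{n\sum_i x_i y_i - \sum_i x_i \sum_i y_i}{n\sum_i x_i^2 - \bigl(\sum_i x_i\bigr)^2}, \qquad \text{intercept} = \frac{\sum_i y_i - \text{slope}\cdot\sum_i x_i}{n}.$$

A minimal usage sketch of the new API follows; the `FitExample` class and the sample arrays are hypothetical and only illustrate the call pattern:

```java
package com.sztzjy.marketing.util.algorithm;

public class FitExample {
    public static void main(String[] args) {
        // Hypothetical sample data: y is roughly 2x + 1
        double[] x = {1, 2, 3, 4, 5};
        double[] y = {3.1, 4.9, 7.2, 9.0, 11.1};

        // fit() returns { intercept, slope }
        double[] params = new LogisticRegression(x, y).fit();
        double intercept = params[0];
        double slope = params[1];
        System.out.println("intercept = " + intercept + ", slope = " + slope);

        // Predict at a new point with the fitted line
        double xNew = 6;
        System.out.println("prediction at x = 6: " + (intercept + slope * xNew));
    }
}
```

Returning `{ intercept, slope }` as a bare two-element array keeps the class dependency-free, though callers must remember the element order.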
File diff suppressed because it is too large
File diff suppressed because it is too large