Compare commits
2 Commits
5f5521b065
...
fc623db0ca
| Author | SHA1 | Date |
|---|---|---|
|
fc623db0ca | 7 months ago |
|
48588b59aa | 7 months ago |
@ -1,121 +1,44 @@
|
||||
package com.sztzjy.marketing.util.algorithm;
|
||||
|
||||
import com.sztzjy.marketing.util.BigDecimalUtils;
|
||||
|
||||
import java.text.DecimalFormat;
|
||||
|
||||
public class LogisticRegression {
|
||||
private double theta0; // 截距项
|
||||
private double theta1; // 斜率
|
||||
private double learningRate; // 学习率
|
||||
|
||||
private int currentEpoch; // 声明currentEpoch变量
|
||||
|
||||
public LogisticRegression(double learningRate) {
|
||||
this.learningRate = learningRate;
|
||||
}
|
||||
|
||||
private double[] x;
|
||||
|
||||
|
||||
private double[] y;
|
||||
|
||||
public LogisticRegression(double learningRate, double initialTheta0, double initialTheta1) {
|
||||
this.learningRate = learningRate;
|
||||
this.theta0 = initialTheta0;
|
||||
this.theta1 = initialTheta1;
|
||||
this.currentEpoch = 0; // 在构造函数中初始化currentEpoch
|
||||
public LogisticRegression(double[] x, double[] y) {
|
||||
this.x = x;
|
||||
this.y = y;
|
||||
}
|
||||
|
||||
// 添加一个方法来获取截距项
|
||||
public double getIntercept() {
|
||||
return theta0;
|
||||
// 计算线性回归参数
|
||||
public double[] fit() {
|
||||
if (x.length != y.length || x.length == 0) {
|
||||
throw new IllegalArgumentException("Invalid input data");
|
||||
}
|
||||
|
||||
// 添加一个方法来获取第一个特征的系数
|
||||
public double getCoefficient() {
|
||||
return theta1;
|
||||
}
|
||||
int n = x.length;
|
||||
double sumX = 0;
|
||||
double sumY = 0;
|
||||
double sumXY = 0;
|
||||
double sumXX = 0;
|
||||
|
||||
// Sigmoid函数
|
||||
private double sigmoid(double z) {
|
||||
return 1 / (1 + Math.exp(-z));
|
||||
// 计算各项和
|
||||
for (int i = 0; i < n; i++) {
|
||||
sumX += x[i];
|
||||
sumY += y[i];
|
||||
sumXY += x[i] * y[i];
|
||||
sumXX += x[i] * x[i];
|
||||
}
|
||||
|
||||
// 预测函数
|
||||
public double predict(double x,double slopes,double intercept) {
|
||||
BigDecimalUtils bigDecimalUtils=new BigDecimalUtils();
|
||||
|
||||
Double mul = bigDecimalUtils.mul(intercept, x);
|
||||
|
||||
return bigDecimalUtils.add(slopes, mul);
|
||||
// 计算斜率和截距
|
||||
double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX);
|
||||
double intercept = (sumY - slope * sumX) / n;
|
||||
|
||||
return new double[] { intercept, slope };
|
||||
}
|
||||
|
||||
private double dotProduct(double[] a, double[] b) {
|
||||
double result = 0.0;
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
result += a[i] * b[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
//
|
||||
// // 梯度下降更新参数
|
||||
// public void updateParameters(double x, double y) {
|
||||
// double h = predict(x);
|
||||
// double error = y - h;
|
||||
// theta1 += learningRate * error * h * (1 - h) * x;
|
||||
// theta0 += learningRate * error * h * (1 - h);
|
||||
// }
|
||||
private double exponentialDecayLearningRate(double currentEpoch, int totalEpochs) {
|
||||
return learningRate * Math.exp(-(currentEpoch / totalEpochs));
|
||||
}
|
||||
|
||||
// 训练模型
|
||||
public void train(double[][] data, double[] labels, int epochs) {
|
||||
double learningRateCurrentEpoch = exponentialDecayLearningRate(currentEpoch, epochs);
|
||||
for (int epoch = 0; epoch < epochs; epoch++) {
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
double x = data[i][0];
|
||||
double y = labels[i];
|
||||
double h = sigmoid(theta0 + theta1 * x);
|
||||
double error = y - h;
|
||||
theta1 += learningRateCurrentEpoch * error * h * (1 - h) * x;
|
||||
theta0 += learningRateCurrentEpoch * error * h * (1 - h);
|
||||
}
|
||||
// 可以添加打印语句查看训练情况
|
||||
// System.out.println("Epoch: " + (epoch + 1) + ", Loss: " + calculateLoss(data, labels));
|
||||
}
|
||||
}
|
||||
public double calculateLoss(double[][] data, double[] labels) {
|
||||
double loss = 0;
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
double x = data[i][0];
|
||||
double y = labels[i];
|
||||
double h = sigmoid(theta0 + theta1 * x);
|
||||
loss += -y * Math.log(h) - (1 - y) * Math.log(1 - h);
|
||||
}
|
||||
return loss / data.length;
|
||||
}
|
||||
|
||||
// public static void main(String[] args) {
|
||||
// LogisticRegression model = new LogisticRegression(0.1); // 创建一个逻辑回归模型实例
|
||||
//
|
||||
//// double[][] data = {{1}, {2}, {3}, {4}, {5}};
|
||||
//// double[] labels = {0, 0, 1, 1, 1}; // 假设这是一个二分类问题
|
||||
//
|
||||
// double[][] data = {{1, 19}, {1, 21}, {2, 20}, {2, 23}, {2, 31}};
|
||||
// double[] labels = {15, 15, 16, 16, 17};
|
||||
//
|
||||
// model.train(data, labels, 1000); // 训练模型
|
||||
//
|
||||
// DecimalFormat df = new DecimalFormat("#.0");
|
||||
//
|
||||
// // 获取并打印截距项和系数
|
||||
// double intercept = model.getIntercept();
|
||||
// double coefficient = model.getCoefficient();
|
||||
// System.out.println("常数项值: " + df.format(intercept));
|
||||
// System.out.println("系数值 " + df.format(coefficient));
|
||||
//
|
||||
// double prediction = model.predict(2);
|
||||
// System.out.println("x=3.5的预测: " + df.format(prediction));
|
||||
// }
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue