whb 7 months ago
commit 10a4000e5c

@ -74,13 +74,25 @@
<!-- 词频分析-->
<!-- <dependency>-->
<!-- <groupId>com.hankcs.hanlp.restful</groupId>-->
<!-- <artifactId>hanlp-restful</artifactId>-->
<!-- <version>0.0.12</version>-->
<!-- </dependency>-->
<!-- 情感分析-->
<dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.8.4</version>
</dependency>
<dependency>
<groupId>com.hankcs.hanlp.restful</groupId>
<artifactId>hanlp-restful</artifactId>
<version>0.0.12</version>
</dependency>
<!-- Apriori算法-->
<!-- <dependency>-->
<!-- <groupId>org.paukov</groupId>-->

@ -4,14 +4,13 @@ import com.sztzjy.marketing.annotation.AnonymousAccess;
import com.sztzjy.marketing.config.exception.handler.DigitalEconomyxception;
import com.sztzjy.marketing.entity.StuTrainingOperateStepExample;
import com.sztzjy.marketing.entity.StuTrainingOperateStepWithBLOBs;
import com.sztzjy.marketing.entity.dto.AssociationRulesDTO;
import com.sztzjy.marketing.entity.dto.ClusterAnalysisDTO;
import com.sztzjy.marketing.entity.dto.StatisticsDTO;
import com.sztzjy.marketing.entity.dto.*;
import com.sztzjy.marketing.service.StuDigitalMarketingModelService;
import com.sztzjy.marketing.util.ResultEntity;
import com.sztzjy.marketing.util.algorithm.Apriori;
import com.sztzjy.marketing.util.algorithm.KMeans;
import com.sztzjy.marketing.util.algorithm.LinearRegression;
import com.sztzjy.marketing.util.algorithm.LogisticRegression;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.annotations.ApiParam;
@ -24,9 +23,11 @@ import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
import java.io.IOException;
import java.math.BigDecimal;
import java.text.DecimalFormat;
import java.util.*;
import static com.sztzjy.marketing.util.algorithm.KMeans.readIrisData;
import static com.sztzjy.marketing.util.algorithm.KMeansResult.*;
/**
* @author tz
@ -72,6 +73,18 @@ public class StuDigitalMarketingModelController {
}
@ApiOperation("数据预处理")
@PostMapping("/dataPreprocessing")
@AnonymousAccess
public ResultEntity dataPreprocessing(@RequestBody DataPreprocessingDTO dto) {
    // Thin controller: the DTO carries user id, missing-value strategy and raw rows;
    // all preprocessing work is delegated to the service layer.
    return modelService.dataPreprocessing(dto.getUserId(), dto.getMethod(), dto.getMapList());
}
@ApiOperation("描述性统计")
@PostMapping("/descriptiveStatistics")
@ -86,10 +99,10 @@ public class StuDigitalMarketingModelController {
}
@ApiOperation("聚类分析")
@PostMapping("/clusterAnalysis")
@ApiOperation("聚类分析--散点图")
@PostMapping("/clusterScatterPlot")
@AnonymousAccess
public ResultEntity clusterAnalysis(@ApiParam("簇数") Integer k,
public ResultEntity clusterScatterPlot(@ApiParam("簇数") Integer k,
@ApiParam("最大迭代次数") Integer t,
@ApiParam("最大迭代次数") String userId,
@RequestParam(required = false) @RequestPart MultipartFile file) {
@ -131,13 +144,37 @@ public class StuDigitalMarketingModelController {
@ApiOperation("聚类分析--分析结果")
@PostMapping("/clusterAnalysisResult")
@AnonymousAccess
// Runs k-means clustering on an uploaded data file.
// NOTE(review): the @ApiParam text on userId says "最大迭代次数" (max iterations) —
// looks like a copy-paste slip; should presumably describe the user id. Confirm before fixing.
public ResultEntity clusterAnalysisResult(@ApiParam("簇数") Integer k,
                                          @ApiParam("最大迭代次数") Integer t,
                                          @ApiParam("最大迭代次数") String userId,
                                          @RequestParam(required = false) @RequestPart MultipartFile file) {
    // Load the uploaded table into memory (static KMeansResult.readTable).
    ArrayList<ArrayList<Float>> arrayLists = readTable(file);
    // Pick k random rows as the initial cluster centers.
    ArrayList<ArrayList<Float>> centerList = randomList(k, arrayLists);
    // Iterate k-means up to t times.
    ArrayList<ArrayList<ArrayList<Float>>> kmeans = kmeans(k, t, centerList);
    // NOTE(review): the clustering result `kmeans` is computed but never returned —
    // the response carries no payload. Confirm whether the frontend expects the clusters here.
    return new ResultEntity(HttpStatus.OK);
}
@ApiOperation("关联规则挖掘")
@PostMapping("/apriori")
@AnonymousAccess
public ResultEntity apriori(@ApiParam("最小支持度阀值") double support,
@ApiParam("最小置信度") double confidence,
@ApiParam("用户ID") String userId,
@RequestParam(required = false) @RequestPart MultipartFile file) {
@RequestParam(required = false) @RequestPart MultipartFile file) throws IOException {
//初始化事务数据库、项目集、候选集再进行剪枝
@ -154,21 +191,39 @@ public class StuDigitalMarketingModelController {
}
@ApiOperation("回归分析")
@ApiOperation("回归分析--线性回归")
@PostMapping("/regressionAnalysis")
@AnonymousAccess
public ResultEntity regressionAnalysis(@ApiParam("自变量X") double[] x,
@ApiParam("因变量Y") double[] y,
@ApiParam("用户ID") String userId) {
public ResultEntity regressionAnalysis(@RequestBody RegressionAnalysisDTO analysisDTO) {
double[][] x = analysisDTO.getX();
double[] y = analysisDTO.getY();
// 创建线性回归模型
LinearRegression regression=new LinearRegression();
regression.fit(x,y);
return new ResultEntity(HttpStatus.OK,"成功",regression);
}
// @ApiOperation("回归分析--逻辑回归")
// @PostMapping("/logistic")
// @AnonymousAccess
// public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) {
// double[] x = logisticDTO.getX();
// double[] y = logisticDTO.getY();
//
// // 创建线性回归模型
// LogisticRegression regression=new LogisticRegression();
// regression.fit(x,y);
//
//
// return new ResultEntity(HttpStatus.OK,"成功",regression);
// }
@ApiOperation("情感分析/文本挖掘")
@PostMapping("/emotionalAnalysis")
@AnonymousAccess

@ -0,0 +1,19 @@
package com.sztzjy.marketing.entity.dto;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
* @author tz
* @date 2024/6/25 14:21
*/
@Data
// Request body for the /dataPreprocessing endpoint.
public class DataPreprocessingDTO {
    // ID of the user making the request
    private String userId;
    // Missing-value strategy; matched literally in the service ("剔除数据" or "均值代替")
    private String method;
    // Raw rows to preprocess; each map is column-name -> value
    private List<Map<String,Object>> mapList;
}

@ -0,0 +1,21 @@
package com.sztzjy.marketing.entity.dto;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
/**
* @author tz
* @date 2024/7/23 16:52
*/
@Data
// Request body for the logistic-regression endpoint
// (the /logistic controller method is currently commented out).
public class LogisticDTO {
    @ApiModelProperty("自变量x")
    private double[] x;
    @ApiModelProperty("因变量y")
    private double[] y;
    @ApiModelProperty("用户ID")
    private String userId;
}

@ -0,0 +1,22 @@
package com.sztzjy.marketing.entity.dto;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import java.util.List;
/**
* @author tz
* @date 2024/7/19 13:56
*/
@Data
// Request body for the /regressionAnalysis (linear regression) endpoint.
public class RegressionAnalysisDTO {
    // Feature matrix: one row per sample, one column per independent variable
    @ApiModelProperty("自变量x")
    private double[][] x;
    // Target vector, one entry per sample
    @ApiModelProperty("因变量y")
    private double[] y;
    @ApiModelProperty("用户ID")
    private String userId;
}

@ -22,4 +22,7 @@ public interface StuDigitalMarketingModelService {
List<String> viewMetrics(String userId, String tableName);
ResultEntity viewAnalyzeData(String userId, String tableName, List<String> fieldList);
ResultEntity dataPreprocessing(String userId, String method, List<Map<String, Object>> mapList);
}

@ -1,5 +1,6 @@
package com.sztzjy.marketing.service.impl;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.common.Term;
import com.sztzjy.marketing.config.Constant;
@ -13,16 +14,16 @@ import com.sztzjy.marketing.service.StuDigitalMarketingModelService;
import com.sztzjy.marketing.util.ResultEntity;
import com.sztzjy.marketing.util.algorithm.DescriptiveStatisticsUtil;
import okhttp3.*;
import org.checkerframework.checker.units.qual.C;
import org.geolatte.geom.M;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import static com.sztzjy.marketing.util.algorithm.BaiDuZhiNengYun.getAccessToken;
/**
* @author tz
* @date 2024/6/14 11:05
@ -69,18 +70,18 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
@Override
public ResultEntity emotionalAnalysis(String userId, String modelType, String content) throws IOException {
String commentUrl = "https://aip.baidubce.com/rpc/2.0/nlp/v2/comment_tag?access_token=24.88968c130db3ca9f266907b5004bec8f.2592000.1721270821.282335-83957582&charset=UTF-8";
String emoUrl="https://aip.baidubce.com/rpc/2.0/nlp/v1/sentiment_classify?access_token=24.88968c130db3ca9f266907b5004bec8f.2592000.1721270821.282335-83957582";
MediaType mediaType = MediaType.parse("application/json");
if(modelType.equals(Constant.COMMENT_EXTRACTION)){ //评论观点抽取
String commentUrl = "https://aip.baidubce.com/rpc/2.0/nlp/v2/comment_tag?charset=UTF-8&access_token="+ getAccessToken();
RequestBody body = RequestBody.create(mediaType, "{\"text\":\""+content+"\",\"type\":8}");
Request request = new Request.Builder()
.url(commentUrl)
.addHeader("Content-Type", "application/json")
.addHeader("Accept", "application/json")
.method("POST", body)
.build();
Response response = HTTP_CLIENT.newCall(request).execute();
@ -91,6 +92,7 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
if (modelType.equals(Constant.EMOTIONAL_TENDENCIES)) { //情感倾向分析
String emoUrl = "https://aip.baidubce.com/rpc/2.0/nlp/v1/sentiment_classify?charset=UTF-8&access_token=" + getAccessToken();
RequestBody body = RequestBody.create(mediaType, "{\"text\":\"" + content + "\"}");
@ -104,6 +106,7 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
return new ResultEntity(HttpStatus.OK, "成功", response.body().string());
}
}
@ -229,7 +232,6 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
statisticsDTO.setStatistics(statistics);
dtoList.add(statisticsDTO);
}
@ -280,9 +282,106 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
List<Map<String,Object>> attributes = tableNameMapper.selectByFields(fieldList,table);
//将日期类型转为string返回
for (Map<String, Object> row : attributes) {
for (Map.Entry<String, Object> entry : row.entrySet()) {
if (entry.getValue() instanceof LocalDateTime) {
entry.setValue(entry.getValue().toString());
}
}
}
// List<Object> createTime = attributes.get("create_time");
return new ResultEntity(HttpStatus.OK,attributes);
}
@Override
// Deduplicates the input rows, then handles missing values according to `method`.
// Returns the cleaned rows in the result payload.
public ResultEntity dataPreprocessing(String userId, String method, List<Map<String, Object>> mapList) {
    // Deduplicate rows by their Map.toString() rendering.
    // NOTE(review): this relies on toString() being a stable row identity — two rows with the
    // same values always produce equal strings for the same map type, but this is fragile.
    Set<String> set = new HashSet<>();
    List<Map<String, Object>> deduplicatedDataList = new ArrayList<>();
    for (Map<String, Object> map : mapList) {
        String mapString = map.toString();
        if (!set.contains(mapString)) {
            set.add(mapString);
            deduplicatedDataList.add(map);
        }
    }
    // Missing-value strategy: drop any row containing a null value.
    // NOTE(review): method.equals(...) throws NPE if method is null — consider "剔除数据".equals(method).
    if(method.equals("剔除数据")){
        deduplicatedDataList.removeIf(record -> record.containsValue(null));
    }
    if(method.equals("均值代替")){
        // NOTE(review): this branch is labelled "mean substitution" but the mean code below is
        // commented out and MODE substitution is performed instead — confirm which is intended.
        // // 均值代替处理
        // for (Map<String, Object> record : deduplicatedDataList) {
        // for (Map.Entry<String, Object> entry : record.entrySet()) {
        // if (entry.getValue() == null) {
        // double mean = calculateMean(deduplicatedDataList, entry.getKey());
        // entry.setValue(mean);
        // }
        // }
        // }
        // Replace each null value with the most frequent String value of that column.
        for (Map<String, Object> record : deduplicatedDataList) {
            for (Map.Entry<String, Object> entry : record.entrySet()) {
                if (entry.getValue() == null) {
                    String mode = calculateMode(deduplicatedDataList, entry.getKey());
                    entry.setValue(mode);
                }
            }
        }
    }
    return new ResultEntity(HttpStatus.OK,"成功",deduplicatedDataList);
}
// Mean-substitution helper: average of the numeric values stored under `key`
// across all rows. Null and non-numeric entries are skipped; returns 0 when the
// column has no numeric values at all.
// (Currently only referenced from the commented-out mean branch of dataPreprocessing.)
private static double calculateMean(List<Map<String, Object>> dataList, String key) {
    double sum = 0;
    int count = 0;
    for (Map<String, Object> record : dataList) {
        Object value = record.get(key);
        // instanceof is already null-safe; the extra `value != null` check was redundant
        if (value instanceof Number) {
            sum += ((Number) value).doubleValue();
            count++;
        }
    }
    return count > 0 ? (sum / count) : 0;
}
// Mode-substitution helper: returns the most frequent String value stored under
// `key` across all rows, or "" when the column holds no String values.
private static String calculateMode(List<Map<String, Object>> dataList, String key) {
    // Tally occurrences of each String value in the column.
    Map<String, Integer> frequency = new HashMap<>();
    for (Map<String, Object> row : dataList) {
        Object candidate = row.get(key);
        if (candidate instanceof String) {
            frequency.merge((String) candidate, 1, Integer::sum);
        }
    }
    // Pick the value with the highest count (first-seen wins on ties).
    String best = null;
    int bestCount = 0;
    for (Map.Entry<String, Integer> tally : frequency.entrySet()) {
        if (tally.getValue() > bestCount) {
            bestCount = tally.getValue();
            best = tally.getKey();
        }
    }
    return best != null ? best : ""; // empty string when no mode exists
}
}

@ -1,10 +1,21 @@
package com.sztzjy.marketing.util.algorithm;
import com.alibaba.excel.EasyExcel;
import com.alibaba.excel.context.AnalysisContext;
import com.alibaba.excel.event.AnalysisEventListener;
import com.sztzjy.marketing.entity.dto.AssociationRulesDTO;
import org.apache.poi.ss.usermodel.Cell;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.springframework.web.multipart.MultipartFile;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;
public class Apriori {
@ -20,29 +31,25 @@ public class Apriori {
static HashMap<ArrayList<String>, Integer> L_ALL = new HashMap<ArrayList<String>, Integer>();
//从文件中读取内容,返回事务集
public static ArrayList<ArrayList<String>> readTable(MultipartFile file) {
ArrayList<ArrayList<String>> t = new ArrayList<ArrayList<String>>();
ArrayList<String> t1 = null;
try {
InputStreamReader isr = new InputStreamReader(file.getInputStream());
BufferedReader bf = new BufferedReader(isr);
String str = null;
while((str = bf.readLine()) != null) {
t1 = new ArrayList<String>();
String[] str1 = str.split(",");
for(int i = 1; i < str1.length ; i++) {
t1.add(str1[i]);
}
t.add(t1);
}
bf.close();
isr.close();
} catch (Exception e) {
ArrayList<ArrayList<String>> t = new ArrayList<>();
try (InputStream inputStream = file.getInputStream();
Workbook workbook = new XSSFWorkbook(inputStream)) {
Sheet sheet = workbook.getSheetAt(0);
for (Row row : sheet) {
ArrayList<String> rowData = new ArrayList<>();
for (Cell cell : row) {
rowData.add(cell.getStringCellValue());
}
t.add(rowData);
}
} catch (IOException e) {
e.printStackTrace();
System.out.println("文件读取失败!");
}
System.out.println("\nD:" + t);
return t;
}
//剪枝步从候选集C中删除小于最小支持度的并放入频繁集L中
public static void pruning(HashMap<ArrayList<String>, Integer> C,HashMap<ArrayList<String>, Integer> L,double min_support) {
L.clear();
@ -63,7 +70,7 @@ public class Apriori {
/**
*
*/
public static void init(MultipartFile file,double min_support) {
public static void init(MultipartFile file,double min_support) throws IOException {
// //将文件中的数据放入集合D中
D = readTable(file);
// 扫描事务数据库。生成项目集,支持度=该元素在事务数据库出现的次数/事务数据库的事务数

@ -0,0 +1,39 @@
package com.sztzjy.marketing.util.algorithm;
import okhttp3.*;
import org.json.JSONObject;
import java.io.IOException;
/**
* @author tz
* @date 2024/7/19 15:25
*/
// Helper for obtaining a Baidu AI (百度智能云) OAuth access token.
public class BaiDuZhiNengYun {
    // SECURITY(review): API credentials are hardcoded in source control —
    // move them to configuration / environment variables and rotate the keys.
    public static final String API_KEY = "12hSC65yXQdFndZridHTaHtd";
    public static final String SECRET_KEY = "0SegCsSEahKBYB4z6ZLELebLMRt37gib";

    static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();

    /**
     * Exchanges the AK/SK pair for an access token via Baidu's OAuth endpoint.
     *
     * @return the access token, or null when the response carries no body
     * @throws IOException if the HTTP call fails
     */
    public static String getAccessToken() throws IOException {
        MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
        RequestBody body = RequestBody.create(mediaType, "grant_type=client_credentials&client_id=" + API_KEY
                + "&client_secret=" + SECRET_KEY);
        Request request = new Request.Builder()
                .url("https://aip.baidubce.com/oauth/2.0/token")
                .method("POST", body)
                .addHeader("Content-Type", "application/x-www-form-urlencoded")
                .build();
        // try-with-resources closes the Response (and its body); the original leaked it.
        try (Response response = HTTP_CLIENT.newCall(request).execute()) {
            ResponseBody responseBody = response.body();
            if (responseBody != null) {
                return new JSONObject(responseBody.string()).getString("access_token");
            }
            return null;
        }
    }
}

@ -5,7 +5,7 @@ import java.io.*;
public class DataDeal {
static double min_support = 4; //最小支持度
static double min_confident = 0.7; //最小置信度
//文件路径,我的是在当前项目下,大家换成自己的文件路径就是了
//文件路径
static String filePath = "D:\\home\\marketing\\1.txt";
static ArrayList<ArrayList<String>> D = new ArrayList<ArrayList<String>>();//事务数据库D
static HashMap<ArrayList<String>, Integer> C = new HashMap<ArrayList<String>, Integer>();//项目集C

@ -171,4 +171,6 @@ public class KMeans {
}
return points;
}
}

@ -0,0 +1,243 @@
package com.sztzjy.marketing.util.algorithm;
import org.springframework.web.multipart.MultipartFile;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
/**
* @author tz
* @date 2024/7/2 17:37
*/
// K-means clustering over a CSV table uploaded as a MultipartFile.
// NOTE(review): this class keeps its working data in static mutable fields, so it is
// NOT thread-safe — concurrent requests through the controller will corrupt each other's state.
public class KMeansResult {
    // Iteration counter used only for the progress printout in eudistance()
    static int count = 1;
    // 文件所在路径
    // static String filePath = "D:\\home\\marketing\\Segmentation_dataset(1).csv";
    // The parsed data table: one inner list per CSV row (populated by readTable)
    static ArrayList<ArrayList<Float>> table = new ArrayList<ArrayList<Float>>();
    // Holders for three fixed clusters — not referenced by the code visible here
    static ArrayList<ArrayList<Float>> alist = new ArrayList<ArrayList<Float>>();
    static ArrayList<ArrayList<Float>> blist = new ArrayList<ArrayList<Float>>();
    static ArrayList<ArrayList<Float>> clist = new ArrayList<ArrayList<Float>>();
    // Initial random cluster centers — shadowed by the randomList(...) method below
    static ArrayList<ArrayList<Float>> randomList = new ArrayList<ArrayList<Float>>();

    // Parses the uploaded file as comma-separated floats into the static `table`.
    // The first line is skipped as a header; every remaining field must parse as a float.
    public static ArrayList<ArrayList<Float>> readTable(MultipartFile file){
        table.clear();
        ArrayList<Float> d = null;
        try (BufferedReader br = new BufferedReader(new InputStreamReader(file.getInputStream()))) {
            // Skip the header row
            br.readLine();
            String str = null;
            while((str = br.readLine()) != null) {
                d = new ArrayList<Float>();
                String[] str1 = str.split(",");
                for(int i = 0; i < str1.length ; i++) {
                    d.add(Float.parseFloat(str1[i]));
                }
                table.add(d);
            }
            // System.out.println(table);
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("文件不存在!");
        }
        return table;
    }

    // Picks k distinct random rows of `table` as initial cluster centers.
    // NOTE(review): indices are drawn from [0, 30) — a magic constant. If the table has
    // fewer than 30 rows this can throw IndexOutOfBounds; if k > 30 the do/while can never
    // produce k distinct values and loops forever. Should be bounded by table.size().
    // The `randomList` parameter is unused (and shadows the static field of the same name).
    public static ArrayList<ArrayList<Float>> randomList(Integer k,ArrayList<ArrayList<Float>> randomList) {
        ArrayList<ArrayList<Float>> centerList=new ArrayList<>();
        int[] list = new int[k];
        // Generate k distinct random numbers
        do {
            for (int i = 0; i < k; i++) {
                list[i] = (int)(Math.random() * 30);
            }
        } while (hasDuplicates(list));
        // For testing purposes, we can take the first k elements from the data set as initial cluster centers
        for (int i = 0; i < k; i++) {
            centerList.add(table.get(list[i]));
        }
        return centerList;
    }

    // Helper: true when the array contains any duplicate value.
    private static boolean hasDuplicates(int[] arr) {
        Set<Integer> set = new HashSet<>();
        for (int val : arr) {
            if (!set.add(val)) {
                return true;
            }
        }
        return false;
    }

    // Returns the smaller of two doubles (not referenced by the code visible here).
    public static double minNumber(double x, double y) {
        if(x < y) {
            return x;
        }
        return y;
    }

    // Assigns every row of `table` to its nearest center in `list` and prints the
    // resulting k groups. Mutates the static `count` iteration counter.
    public static void eudistance(ArrayList<ArrayList<Float>> list, int k) {
        ArrayList<ArrayList<ArrayList<Float>>> clusters = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            clusters.add(new ArrayList<>());
        }
        for (int i = 0; i < table.size(); i++) {
            double[] distances = new double[k];
            for (int j = 0; j < k; j++) {
                distances[j] = calculateEuclideanDistance(table.get(i), list.get(j));
            }
            int closestCluster = findClosestCluster(distances);
            clusters.get(closestCluster).add(table.get(i));
        }
        System.out.println("第" + count + "次迭代:");
        for (int i = 0; i < k; i++) {
            System.out.println("质心 " + i + ": " + clusters.get(i));
        }
        System.out.println();
        count++;
    }

    // Euclidean distance between two rows.
    // NOTE: the loop starts at index 1, i.e. column 0 is excluded from the distance —
    // presumably an ID column; confirm against the input file format.
    private static double calculateEuclideanDistance(ArrayList<Float> point1, ArrayList<Float> point2) {
        double sum = 0;
        for (int i = 1; i < point1.size(); i++) {
            sum += Math.pow(point1.get(i) - point2.get(i), 2);
        }
        return Math.sqrt(sum);
    }

    // Index of the smallest entry in a distance array (ties go to the first).
    private static int findClosestCluster(double[] distances) {
        double minDistance = Double.MAX_VALUE;
        int closestCluster = 0;
        for (int i = 0; i < distances.length; i++) {
            if (distances[i] < minDistance) {
                minDistance = distances[i];
                closestCluster = i;
            }
        }
        return closestCluster;
    }

    // Lloyd's algorithm: assign rows to centroids, recompute centroids, repeat until the
    // centroids stop changing or t iterations have run. Returns the final k clusters.
    public static ArrayList<ArrayList<ArrayList<Float>>> kmeans(int k,Integer t,ArrayList<ArrayList<Float>> centroids) {
        ArrayList<ArrayList<ArrayList<Float>>> clusters = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            clusters.add(new ArrayList<>());
        }
        boolean converged = false;
        int iterations = 0;
        while (!converged && iterations < t) {
            // Assignment step
            for (int i = 0; i < table.size(); i++) {
                ArrayList<Float> point = table.get(i);
                int closestCluster = findClosestCluster(point, centroids);
                clusters.get(closestCluster).add(point);
            }
            // Update step
            ArrayList<ArrayList<Float>> newCentroids = new ArrayList<>(k);
            for (int i = 0; i < k; i++) {
                newCentroids.add(calculateCentroid(clusters.get(i)));
            }
            converged = checkConvergence(centroids, newCentroids);
            if(!converged){
                // Not converged: adopt the new centroids, clear assignments, iterate again
                centroids = newCentroids;
                for (int i = 0; i < k; i++) {
                    clusters.get(i).clear();
                }
                iterations++;
                continue;
            }else {
                break;
            }
        }
        System.out.println("最终聚类结果:");
        for (int i = 0; i < k; i++) {
            System.out.println("簇 " + i + ": " + clusters.get(i));
        }
        return clusters;
    }

    // Mean of a cluster's rows, columns 1..4 only; column 0 of the centroid is fixed to 0.
    // NOTE(review): hardcodes exactly 4 feature columns (indices 1-4) — throws on rows with
    // fewer columns, and divides by zero (NaN) for an empty cluster. Confirm the data shape.
    private static ArrayList<Float> calculateCentroid(ArrayList<ArrayList<Float>> cluster) {
        float sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0;
        for (ArrayList<Float> point : cluster) {
            sum1 += point.get(1);
            sum2 += point.get(2);
            sum3 += point.get(3);
            sum4 += point.get(4);
        }
        float size = cluster.size();
        ArrayList<Float> centroid = new ArrayList<>();
        centroid.add(0f);
        centroid.add(sum1 / size);
        centroid.add(sum2 / size);
        centroid.add(sum3 / size);
        centroid.add(sum4 / size);
        return centroid;
    }

    // Convergence test: exact element-wise equality of old and new centroids.
    // (Exact Float comparison — float rounding may delay convergence until the cap t.)
    private static boolean checkConvergence(ArrayList<ArrayList<Float>> oldCentroids, ArrayList<ArrayList<Float>> newCentroids) {
        for (int i = 0; i < oldCentroids.size(); i++) {
            if (!oldCentroids.get(i).equals(newCentroids.get(i))) {
                return false;
            }
        }
        return true;
    }

    // Index of the centroid nearest to `point` (overload taking the centroid list directly).
    private static int findClosestCluster(ArrayList<Float> point, ArrayList<ArrayList<Float>> centroids) {
        double minDistance = Double.MAX_VALUE;
        int closestCluster = 0;
        for (int i = 0; i < centroids.size(); i++) {
            double distance = calculateEuclideanDistance(point, centroids.get(i));
            if (distance < minDistance) {
                minDistance = distance;
                closestCluster = i;
            }
        }
        return closestCluster;
    }

    public static void main(String[] args) {
        // ArrayList<ArrayList<Float>> rlist = new ArrayList<ArrayList<Float>>();
        // readTable(filePath);
        // rlist = randomList();
        // eudistance(rlist);
        // kmeans();
    }
}

@ -1,31 +1,61 @@
package com.sztzjy.marketing.util.algorithm;
import java.text.DecimalFormat;
public class LinearRegression {
private double intercept;
private double slope;
private double[] slopes;
// 线性回归的训练方法
public void fit(double[] x, double[] y) {
int n = x.length;
double sumX = 0.0, sumY = 0.0, sumXY = 0.0, sumX2 = 0.0;
public void fit(double[][] x, double[] y) {
int n = y.length;
int m = x[0].length;
double[] sumX = new double[m];
double sumY = 0.0;
double[] sumXY = new double[m];
double[] sumX2 = new double[m];
for (int i = 0; i < n; i++) {
sumX += x[i];
for (int j = 0; j < m; j++) {
sumX[j] += x[i][j];
sumXY[j] += x[i][j] * y[i];
sumX2[j] += x[i][j] * x[i][j];
}
sumY += y[i];
sumXY += x[i] * y[i];
sumX2 += x[i] * x[i];
}
double xMean = sumX / n;
double[] xMean = new double[m];
for (int j = 0; j < m; j++) {
xMean[j] = sumX[j] / n;
}
double yMean = sumY / n;
this.slope = (sumXY - n * xMean * yMean) / (sumX2 - n * xMean * xMean);
this.intercept = yMean - this.slope * xMean;
DecimalFormat df = new DecimalFormat("#.00");
this.slopes = new double[m];
for (int j = 0; j < m; j++) {
this.slopes[j] = Double.parseDouble(df.format((sumXY[j] - n * xMean[j] * yMean) / (sumX2[j] - n * xMean[j] * xMean[j])));
}
this.intercept = Double.parseDouble(df.format(yMean - dotProduct(this.slopes, xMean)));
}
// 辅助方法:计算向量的点积
private double dotProduct(double[] a, double[] b) {
double result = 0.0;
for (int i = 0; i < a.length; i++) {
result += a[i] * b[i];
}
return result;
}
// 预测方法
public double predict(double x) {
return this.intercept + this.slope * x;
public double predict(double[] x) {
if (x.length != this.slopes.length) {
throw new IllegalArgumentException("输入的x维度与模型不匹配");
}
return this.intercept + dotProduct(this.slopes, x);
}
// 获取截距
@ -34,14 +64,14 @@ public class LinearRegression {
}
// 获取斜率
public double getSlope() {
return this.slope;
public double[] getSlopes() {
return this.slopes;
}
public static void main(String[] args) {
// 示例数据
double[] x = {1, 2, 3, 4, 5};
double[] y = {2, 3, 5, 6, 8};
double[][] x = {{1, 19}, {1, 21}, {2, 20}, {2, 23}, {2, 31}};
double[] y = {15, 15, 16, 16, 17};
// 创建线性回归模型
LinearRegression lr = new LinearRegression();
@ -49,11 +79,11 @@ public class LinearRegression {
// 输出模型参数
System.out.println("截距: " + lr.getIntercept());
System.out.println("斜率: " + lr.getSlope());
System.out.println("斜率: " + java.util.Arrays.toString(lr.getSlopes()));
// 预测新数据
double newX = 6;
double[] newX = {2, 45};
double predictedY = lr.predict(newX);
System.out.println("输入x: " + newX + "预测y=" + predictedY);
System.out.println("输入x: " + java.util.Arrays.toString(newX) + "预测y=" + predictedY);
}
}

@ -1,196 +1,71 @@
package com.sztzjy.marketing.util.algorithm;
import java.util.Arrays;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
/**
* https://blog.csdn.net/weixin_45040801/article/details/102542209
* https://blog.csdn.net/ccblogger/article/details/81739200
* */
public class LogisticRegression {
private double[] weights;
private double bias;
/**
*
*
* @param data
* @param classValues
* @return
*/
public static double[] train(double[][] data, double[] classValues) {
if (data != null && classValues != null && data.length == classValues.length) {
// 期望矩阵
// Matrix matrWeights = DenseMatrix.Factory.zeros(data[0].length + 1, 1);
Matrix matrWeights = DenseMatrix.Factory.zeros(data[0].length, 1);
System.out.println("data[0].length + 1========"+data[0].length + 1);
// 数据矩阵
// Matrix matrData = DenseMatrix.Factory.zeros(data.length, data[0].length + 1);
Matrix matrData = DenseMatrix.Factory.zeros(data.length, data[0].length);
// 标志矩阵
Matrix matrLable = DenseMatrix.Factory.zeros(data.length, 1);
// 训练速率矩阵
// Matrix matrRate = DenseMatrix.Factory.zeros(data[0].length + 1,data.length);
// 统计difference的总体失误的辅助矩阵
Matrix matrDiffUtil = DenseMatrix.Factory.zeros(data[0].length,data.length);
for(int i=0;i<data.length;i++){
for(int j=0;j<data[0].length;j++) {
matrDiffUtil.setAsDouble(1, j, i);
public LogisticRegression(int numFeatures) {
weights = new double[numFeatures];
bias = 0;
}
}
System.out.println("matrDiffUtil="+matrDiffUtil);
/**
*
*
* */
// System.out.println("matrRate======="+matrRate);
// 设置训练速率矩阵
// for(int i=0;i<data[0].length + 1;i++){
// for(int j=0;j<data.length;j++){
// }
// }
for (int i = 0; i < data.length; i++) {
// matrData.setAsDouble(1.0, i, 0);
// 初始化标志矩阵
matrLable.setAsDouble(classValues[i], i, 0);
for (int j = 0; j < data[0].length; j++) {
// 初始化数据矩阵
// matrData.setAsDouble(data[i][j], i, j + 1);
matrData.setAsDouble(data[i][j], i, j);
if (i == 0) {
// 初始化期望矩阵
// matrWeights.setAsDouble(1.0, j+1, 0);
matrWeights.setAsDouble(1.0, j, 0);
}
}
}
// matrWeights.setAsDouble(-0.5, data[0].length, 0);
// matrRate = matrData.transpose().times(0.9);
// System.out.println("matrRate============"+matrRate);
public void fit(double[][] X, double[] y, double learningRate, int numIterations) {
int numSamples = X.length;
int numFeatures = X[0].length;
double step = 0.011;
int maxCycle = 5000000;
// int maxCycle = 5;
for (int iter = 0; iter < numIterations; iter++) {
double[] gradientWeights = new double[numFeatures];
double gradientBias = 0;
System.out.println("matrData======"+matrData);
System.out.println("matrWeights"+matrWeights);
System.out.println("matrLable"+matrLable);
for (int i = 0; i < numSamples; i++) {
double dotProduct = dotProduct(X[i], weights) + bias;
double prediction = sigmoid(dotProduct);
double error = y[i] - prediction;
/**
* 使
* https://blog.csdn.net/lionel_fengj/article/details/53400715
*
* */
for (int i = 0; i < maxCycle; i++) {
// 将想要函数转换为sigmoid函数并得到的值
Matrix h = sigmoid(matrData.mtimes(matrWeights));
// System.out.println("h=="+h);
// 求出预期和真实的差值
Matrix difference = matrLable.minus(h);
// System.out.println("difference "+difference);
// matrData转置后和difference相乘得到预期和真实差值的每一个值
// 公式:@0 = @0 - ax(y0-y),可以参考https://blog.csdn.net/ccblogger/article/details/81739200
matrWeights = matrWeights.plus(matrData.transpose().mtimes(difference).times(step));
// matrWeights = matrWeights.plus(matrRate.mtimes(difference).times(step));
// matrWeights = matrWeights.plus(matrDiffUtil.mtimes(difference).times(step));
for (int j = 0; j < numFeatures; j++) {
gradientWeights[j] += error * X[i][j];
}
double[] rtn = new double[(int) matrWeights.getRowCount()];
for (long i = 0; i < matrWeights.getRowCount(); i++) {
rtn[(int) i] = matrWeights.getAsDouble(i, 0);
gradientBias += error;
}
return rtn;
for (int j = 0; j < numFeatures; j++) {
weights[j] += learningRate * gradientWeights[j] / numSamples;
}
return null;
bias += learningRate * gradientBias / numSamples;
}
/**
* sigmoid
*
* @param sourceMatrix
* @return sigmoid
*/
public static Matrix sigmoid(Matrix sourceMatrix) {
Matrix rtn = DenseMatrix.Factory.zeros(sourceMatrix.getRowCount(), sourceMatrix.getColumnCount());
for (int i = 0; i < sourceMatrix.getRowCount(); i++) {
for (int j = 0; j < sourceMatrix.getColumnCount(); j++) {
rtn.setAsDouble(sigmoid(sourceMatrix.getAsDouble(i, j)), i, j);
}
public double predict(double[] x) {
double dotProduct = dotProduct(x, weights) + bias;
return sigmoid(dotProduct);
}
return rtn;
private double sigmoid(double x) {
return 1 / (1 + Math.exp(-x));
}
/**
* sigmoid
*
* @param source
* @return sigmoid
*/
public static double sigmoid(double source) {
return 1.0 / (1 + Math.exp(-1 * source));
private double dotProduct(double[] x, double[] y) {
double sum = 0;
for (int i = 0; i < x.length; i++) {
sum += x[i] * y[i];
}
// 测试预测值:
/**
*
*
* @param sourceData
* @param model
* @return
*/
public static double getValue(double[] sourceData, double[] model) {
// double logisticRegressionValue = model[0];
double logisticRegressionValue = 0;
for (int i = 0; i < sourceData.length; i++) {
// logisticRegressionValue = logisticRegressionValue + sourceData[i] * model[i + 1];
logisticRegressionValue = logisticRegressionValue + sourceData[i] * model[i];
return sum;
}
logisticRegressionValue = sigmoid(logisticRegressionValue);
return logisticRegressionValue;
}
public static void main(String[] args) {
double[][] X = {{1, 1}, {1, 2}, {2, 1}, {2, 3}};
double[] y = {0, 0, 0, 1};
LogisticRegression model = new LogisticRegression(2);
model.fit(X, y, 0.01, 10000);
System.out.println("Weights: " + Arrays.toString(model.weights));
System.out.println("Bias: " + model.bias);
public static void main(String[] args) {
double[][] sourceData = new double[][]{{-1, 1}, {0, 1}, {1, -1}, {1, 0}, {0, 0.1}, {0, -0.1}, {-1, -1.1}, {1, 0.9}};
double[] classValue = new double[]{1, 1, 0, 0, 1, 0, 0, 0};
double[] modle = LogisticRegression.train(sourceData, classValue);
// double logicValue = LogisticRegression.getValue(new double[] { 0, 0 }, modle);
double logicValue0 = LogisticRegression.getValue(new double[]{-1, 1}, modle);
double logicValue1 = LogisticRegression.getValue(new double[]{0, 1}, modle);
double logicValue2 = LogisticRegression.getValue(new double[]{1, -1}, modle);//3.1812246935599485E-60无限趋近于0
double logicValue3 = LogisticRegression.getValue(new double[]{1, 0}, modle);//3.091713602147872E-30无限趋近于0
double logicValue4 = LogisticRegression.getValue(new double[]{0, 0.1}, modle);//3.091713602147872E-30无限趋近于0
double logicValue5 = LogisticRegression.getValue(new double[]{0, -0.1}, modle);//3.091713602147872E-30无限趋近于0
System.out.println("---model---");
for (int i = 0; i < modle.length; i++) {
System.out.println(modle[i]);
}
System.out.println("-----------");
// System.out.println(logicValue);
System.out.println(logicValue0);
System.out.println(logicValue1);
System.out.println(logicValue2);
System.out.println(logicValue3);
System.out.println(logicValue4);
System.out.println(logicValue5);
double[] newSample = {1, 2};
double prediction = model.predict(newSample);
System.out.println("Prediction for " + Arrays.toString(newSample) + ": " + prediction);
}
}

@ -0,0 +1,21 @@
package com.sztzjy.marketing.util.algorithm;
import cn.hutool.core.io.resource.ClassPathResource;
import com.hankcs.hanlp.corpus.io.IIOAdapter;
import java.io.*;
import java.nio.file.Files;
// HanLP IO adapter that resolves dictionary/model paths against the classpath.
// NOTE(review): ClassPathResource.getFile() only works when resources live on the
// filesystem — it fails for resources packed inside a jar. Confirm the deployment
// layout (exploded vs. fat-jar) before relying on this adapter.
public class ResourceFileIoAdapter implements IIOAdapter {
    @Override
    public InputStream open(String path) throws IOException {
        // Resolve the classpath entry and open it for reading
        ClassPathResource resource = new ClassPathResource(path);
        return Files.newInputStream(resource.getFile().toPath());
    }
    @Override
    public OutputStream create(String path) throws IOException {
        // Opens a classpath entry for writing — writing into the classpath is unusual;
        // presumably used by HanLP to cache compiled dictionaries. Verify intent.
        ClassPathResource resource = new ClassPathResource(path);
        return Files.newOutputStream(resource.getFile().toPath());
    }
}
Loading…
Cancel
Save