diff --git a/pom.xml b/pom.xml index ba122c9..92428ca 100644 --- a/pom.xml +++ b/pom.xml @@ -74,13 +74,25 @@ + + + + + + + + + com.hankcs + hanlp + portable-1.8.4 + + com.hankcs.hanlp.restful hanlp-restful 0.0.12 - diff --git a/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java b/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java index b63c37a..22a963f 100644 --- a/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java +++ b/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java @@ -4,14 +4,13 @@ import com.sztzjy.marketing.annotation.AnonymousAccess; import com.sztzjy.marketing.config.exception.handler.DigitalEconomyxception; import com.sztzjy.marketing.entity.StuTrainingOperateStepExample; import com.sztzjy.marketing.entity.StuTrainingOperateStepWithBLOBs; -import com.sztzjy.marketing.entity.dto.AssociationRulesDTO; -import com.sztzjy.marketing.entity.dto.ClusterAnalysisDTO; -import com.sztzjy.marketing.entity.dto.StatisticsDTO; +import com.sztzjy.marketing.entity.dto.*; import com.sztzjy.marketing.service.StuDigitalMarketingModelService; import com.sztzjy.marketing.util.ResultEntity; import com.sztzjy.marketing.util.algorithm.Apriori; import com.sztzjy.marketing.util.algorithm.KMeans; import com.sztzjy.marketing.util.algorithm.LinearRegression; +import com.sztzjy.marketing.util.algorithm.LogisticRegression; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiParam; @@ -24,9 +23,11 @@ import org.springframework.web.multipart.MultipartFile; import javax.annotation.Resource; import java.io.IOException; import java.math.BigDecimal; +import java.text.DecimalFormat; import java.util.*; import static com.sztzjy.marketing.util.algorithm.KMeans.readIrisData; +import static com.sztzjy.marketing.util.algorithm.KMeansResult.*; /** * @author tz @@ -72,6 +73,18 @@ public class StuDigitalMarketingModelController { } + @ApiOperation("数据预处理") + @PostMapping("/dataPreprocessing") + @AnonymousAccess + public ResultEntity dataPreprocessing(@RequestBody DataPreprocessingDTO dto) { + + String userId = dto.getUserId(); + String method = dto.getMethod(); + List> mapList = dto.getMapList(); + + return modelService.dataPreprocessing(userId,method,mapList); + } + @ApiOperation("描述性统计") @PostMapping("/descriptiveStatistics") @@ -86,13 +99,13 @@ public class StuDigitalMarketingModelController { } - @ApiOperation("聚类分析") - @PostMapping("/clusterAnalysis") + @ApiOperation("聚类分析--散点图") + @PostMapping("/clusterScatterPlot") @AnonymousAccess - public ResultEntity clusterAnalysis(@ApiParam("簇数") Integer k, - @ApiParam("最大迭代次数") Integer t, - @ApiParam("最大迭代次数") String userId, - @RequestParam(required = false) @RequestPart MultipartFile file) { + public ResultEntity clusterScatterPlot(@ApiParam("簇数") Integer k, + @ApiParam("最大迭代次数") Integer t, + @ApiParam("最大迭代次数") String userId, + @RequestParam(required = false) @RequestPart MultipartFile file) { // //验证文件类型 // if (!file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf(".")).equals(".xls") && !file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf(".")).equals(".xlsx")) { @@ -131,13 +144,37 @@ public class StuDigitalMarketingModelController { + + @ApiOperation("聚类分析--分析结果") + @PostMapping("/clusterAnalysisResult") + @AnonymousAccess + public ResultEntity clusterAnalysisResult(@ApiParam("簇数") Integer k, + @ApiParam("最大迭代次数") Integer t, + @ApiParam("最大迭代次数") String userId, + @RequestParam(required = false) @RequestPart MultipartFile file) { + //初始化数据 + ArrayList> arrayLists = readTable(file); + + + //随机选择 k 个数据点作为初始聚类中心 + ArrayList> centerList = randomList(k, arrayLists); + + //开始迭代 + ArrayList>> kmeans = kmeans(k, t, centerList); + + + return new ResultEntity(HttpStatus.OK); + } + + + @ApiOperation("关联规则挖掘") @PostMapping("/apriori") @AnonymousAccess public ResultEntity apriori(@ApiParam("最小支持度阀值") double support, @ApiParam("最小置信度") double confidence, @ApiParam("用户ID") String userId, - @RequestParam(required = false) @RequestPart MultipartFile file) { + @RequestParam(required = false) @RequestPart MultipartFile file) throws IOException { //初始化事务数据库、项目集、候选集再进行剪枝 @@ -154,21 +191,39 @@ public class StuDigitalMarketingModelController { } - @ApiOperation("回归分析") + @ApiOperation("回归分析--线性回归") @PostMapping("/regressionAnalysis") @AnonymousAccess - public ResultEntity regressionAnalysis(@ApiParam("自变量X") double[] x, - @ApiParam("因变量Y") double[] y, - @ApiParam("用户ID") String userId) { + public ResultEntity regressionAnalysis(@RequestBody RegressionAnalysisDTO analysisDTO) { + double[][] x = analysisDTO.getX(); + double[] y = analysisDTO.getY(); + // 创建线性回归模型 LinearRegression regression=new LinearRegression(); regression.fit(x,y); + return new ResultEntity(HttpStatus.OK,"成功",regression); } +// @ApiOperation("回归分析--逻辑回归") +// @PostMapping("/logistic") +// @AnonymousAccess +// public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) { +// double[] x = logisticDTO.getX(); +// double[] y = logisticDTO.getY(); +// +// // 创建线性回归模型 +// LogisticRegression regression=new LogisticRegression(); +// regression.fit(x,y); +// +// +// return new ResultEntity(HttpStatus.OK,"成功",regression); +// } + + @ApiOperation("情感分析/文本挖掘") @PostMapping("/emotionalAnalysis") @AnonymousAccess diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/DataPreprocessingDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/DataPreprocessingDTO.java new file mode 100644 index 0000000..7c2f926 --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/entity/dto/DataPreprocessingDTO.java @@ -0,0 +1,19 @@ +package com.sztzjy.marketing.entity.dto; + +import lombok.Data; + +import java.util.List; +import java.util.Map; + +/** + * @author tz + * @date 2024/6/25 14:21 + */ +@Data +public class DataPreprocessingDTO { + private String userId; + + private String method; + + private List> mapList; +} diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java new file mode 100644 index 0000000..b32dad1 --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java @@ -0,0 +1,21 @@ +package com.sztzjy.marketing.entity.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +/** + * @author tz + * @date 2024/7/23 16:52 + */ +@Data +public class LogisticDTO { + + @ApiModelProperty("自变量x") + private double[] x; + + @ApiModelProperty("因变量y") + private double[] y; + + @ApiModelProperty("用户ID") + private String userId; +} diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/RegressionAnalysisDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/RegressionAnalysisDTO.java new file mode 100644 index 0000000..c0610d0 --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/entity/dto/RegressionAnalysisDTO.java @@ -0,0 +1,22 @@ +package com.sztzjy.marketing.entity.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.List; + +/** + * @author tz + * @date 2024/7/19 13:56 + */ +@Data +public class RegressionAnalysisDTO { + @ApiModelProperty("自变量x") + private double[][] x; + + @ApiModelProperty("因变量y") + private double[] y; + + @ApiModelProperty("用户ID") + private String userId; +} diff --git a/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java b/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java index 47a7fe9..e387be9 100644 --- a/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java +++ b/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java @@ -22,4 +22,7 @@ public interface StuDigitalMarketingModelService { List viewMetrics(String userId, String tableName); ResultEntity viewAnalyzeData(String userId, String tableName, List fieldList); + + ResultEntity dataPreprocessing(String userId, String method, List> mapList); + } diff --git a/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java b/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java index 675f7a2..731e8c2 100644 --- a/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java +++ b/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java @@ -1,5 +1,6 @@ package com.sztzjy.marketing.service.impl; + import com.hankcs.hanlp.HanLP; import com.hankcs.hanlp.seg.common.Term; import com.sztzjy.marketing.config.Constant; @@ -13,16 +14,16 @@ import com.sztzjy.marketing.service.StuDigitalMarketingModelService; import com.sztzjy.marketing.util.ResultEntity; import com.sztzjy.marketing.util.algorithm.DescriptiveStatisticsUtil; import okhttp3.*; -import org.checkerframework.checker.units.qual.C; -import org.geolatte.geom.M; -import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpStatus; import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.IOException; +import java.time.LocalDateTime; import java.util.*; +import static com.sztzjy.marketing.util.algorithm.BaiDuZhiNengYun.getAccessToken; + /** * @author tz * @date 2024/6/14 11:05 @@ -69,18 +70,18 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM @Override public ResultEntity emotionalAnalysis(String userId, String modelType, String content) throws IOException { - String commentUrl = "https://aip.baidubce.com/rpc/2.0/nlp/v2/comment_tag?access_token=24.88968c130db3ca9f266907b5004bec8f.2592000.1721270821.282335-83957582&charset=UTF-8"; - String emoUrl="https://aip.baidubce.com/rpc/2.0/nlp/v1/sentiment_classify?access_token=24.88968c130db3ca9f266907b5004bec8f.2592000.1721270821.282335-83957582"; - MediaType mediaType = MediaType.parse("application/json"); if(modelType.equals(Constant.COMMENT_EXTRACTION)){ //评论观点抽取 + String commentUrl = "https://aip.baidubce.com/rpc/2.0/nlp/v2/comment_tag?charset=UTF-8&access_token="+ getAccessToken(); RequestBody body = RequestBody.create(mediaType, "{\"text\":\""+content+"\",\"type\":8}"); Request request = new Request.Builder() .url(commentUrl) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", "application/json") .method("POST", body) .build(); Response response = HTTP_CLIENT.newCall(request).execute(); @@ -90,9 +91,10 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } } - if(modelType.equals(Constant.EMOTIONAL_TENDENCIES)){ //情感倾向分析 + if (modelType.equals(Constant.EMOTIONAL_TENDENCIES)) { //情感倾向分析 + String emoUrl = "https://aip.baidubce.com/rpc/2.0/nlp/v1/sentiment_classify?charset=UTF-8&access_token=" + getAccessToken(); - RequestBody body = RequestBody.create(mediaType, "{\"text\":\""+content+"\"}"); + RequestBody body = RequestBody.create(mediaType, "{\"text\":\"" + content + "\"}"); Request request = new Request.Builder() .url(emoUrl) @@ -101,9 +103,10 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM Response response = HTTP_CLIENT.newCall(request).execute(); if (response.body() != null) { - return new ResultEntity(HttpStatus.OK,"成功",response.body().string()); + return new ResultEntity(HttpStatus.OK, "成功", response.body().string()); } + } @@ -229,7 +232,6 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } statisticsDTO.setStatistics(statistics); - dtoList.add(statisticsDTO); } @@ -280,9 +282,106 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM List> attributes = tableNameMapper.selectByFields(fieldList,table); + //将日期类型转为string返回 + for (Map row : attributes) { + for (Map.Entry entry : row.entrySet()) { + if (entry.getValue() instanceof LocalDateTime) { + entry.setValue(entry.getValue().toString()); + } + } + } + -// List createTime = attributes.get("create_time"); return new ResultEntity(HttpStatus.OK,attributes); } + + @Override + public ResultEntity dataPreprocessing(String userId, String method, List> mapList) { + //数据去重 + Set set = new HashSet<>(); + List> deduplicatedDataList = new ArrayList<>(); + + for (Map map : mapList) { + String mapString = map.toString(); + if (!set.contains(mapString)) { + set.add(mapString); + deduplicatedDataList.add(map); + } + } + + + //判断缺失值处理方式 + if(method.equals("剔除数据")){ + + + deduplicatedDataList.removeIf(record -> record.containsValue(null)); + + + } + + + + if(method.equals("均值代替")){ + +// // 均值代替处理 +// for (Map record : deduplicatedDataList) { +// for (Map.Entry entry : record.entrySet()) { +// if (entry.getValue() == null) { +// double mean = calculateMean(deduplicatedDataList, entry.getKey()); +// entry.setValue(mean); +// } +// } +// } + + //众数代替 + for (Map record : deduplicatedDataList) { + for (Map.Entry entry : record.entrySet()) { + if (entry.getValue() == null) { + String mode = calculateMode(deduplicatedDataList, entry.getKey()); + entry.setValue(mode); + } + } + } + } + return new ResultEntity(HttpStatus.OK,"成功",deduplicatedDataList); + } + + + //均值代替计算方式 + private static double calculateMean(List> dataList, String key) { + double sum = 0; + int count = 0; + for (Map record : dataList) { + Object value = record.get(key); + if (value != null && value instanceof Number) { + sum += ((Number) value).doubleValue(); + count++; + } + } + return count > 0 ? (sum / count) : 0; + } + + + //众数代替计算方式 + private static String calculateMode(List> dataList, String key) { + Map counts = new HashMap<>(); + for (Map record : dataList) { + Object value = record.get(key); + if (value != null && value instanceof String) { + String stringValue = (String) value; + counts.put(stringValue, counts.getOrDefault(stringValue, 0) + 1); + } + } + int maxCount = 0; + String mode = null; + for (Map.Entry entry : counts.entrySet()) { + if (entry.getValue() > maxCount) { + maxCount = entry.getValue(); + mode = entry.getKey(); + } + } + return mode != null ? mode : ""; // 如果存在众数则返回,否则返回空字符串 + } + } diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java b/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java index 14ced86..96db4dd 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java @@ -1,10 +1,21 @@ package com.sztzjy.marketing.util.algorithm; +import com.alibaba.excel.EasyExcel; +import com.alibaba.excel.context.AnalysisContext; +import com.alibaba.excel.event.AnalysisEventListener; import com.sztzjy.marketing.entity.dto.AssociationRulesDTO; +import org.apache.poi.ss.usermodel.Cell; +import org.apache.poi.ss.usermodel.Row; +import org.apache.poi.ss.usermodel.Sheet; +import org.apache.poi.ss.usermodel.Workbook; +import org.apache.poi.xssf.usermodel.XSSFWorkbook; import org.springframework.web.multipart.MultipartFile; import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.*; public class Apriori { @@ -19,30 +30,26 @@ public class Apriori { // 用于存取候选集每次计算结果(即存放所有的频繁项集L),最后计算关联规则,就不用再次遍历事务数据库。 static HashMap, Integer> L_ALL = new HashMap, Integer>(); //从文件中读取内容,返回事务集 - public static ArrayList> readTable(MultipartFile file){ - ArrayList> t = new ArrayList>(); - ArrayList t1 = null; - try { - InputStreamReader isr = new InputStreamReader(file.getInputStream()); - BufferedReader bf = new BufferedReader(isr); - String str = null; - while((str = bf.readLine()) != null) { - t1 = new ArrayList(); - String[] str1 = str.split(","); - for(int i = 1; i < str1.length ; i++) { - t1.add(str1[i]); + public static ArrayList> readTable(MultipartFile file) { + ArrayList> t = new ArrayList<>(); + try (InputStream inputStream = file.getInputStream(); + Workbook workbook = new XSSFWorkbook(inputStream)) { + Sheet sheet = workbook.getSheetAt(0); + for (Row row : sheet) { + ArrayList rowData = new ArrayList<>(); + for (Cell cell : row) { + rowData.add(cell.getStringCellValue()); } - t.add(t1); + t.add(rowData); } - bf.close(); - isr.close(); - } catch (Exception e) { + } catch (IOException e) { e.printStackTrace(); System.out.println("文件读取失败!"); } - System.out.println("\nD:"+t); + System.out.println("\nD:" + t); return t; } + //剪枝步:从候选集C中删除小于最小支持度的,并放入频繁集L中 public static void pruning(HashMap, Integer> C,HashMap, Integer> L,double min_support) { L.clear(); @@ -63,7 +70,7 @@ public class Apriori { /** * 初始化事务数据库、项目集、候选集 */ - public static void init(MultipartFile file,double min_support) { + public static void init(MultipartFile file,double min_support) throws IOException { // //将文件中的数据放入集合D中 D = readTable(file); // 扫描事务数据库。生成项目集,支持度=该元素在事务数据库出现的次数/事务数据库的事务数 diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/BaiDuZhiNengYun.java b/src/main/java/com/sztzjy/marketing/util/algorithm/BaiDuZhiNengYun.java new file mode 100644 index 0000000..ddf3a5f --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/BaiDuZhiNengYun.java @@ -0,0 +1,39 @@ +package com.sztzjy.marketing.util.algorithm; + +import okhttp3.*; +import org.json.JSONObject; + +import java.io.IOException; + +/** + * @author tz + * @date 2024/7/19 15:25 + */ +public class BaiDuZhiNengYun { + public static final String API_KEY = "12hSC65yXQdFndZridHTaHtd"; + public static final String SECRET_KEY = "0SegCsSEahKBYB4z6ZLELebLMRt37gib"; + + static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build(); + + /** + * 从用户的AK,SK生成鉴权签名(Access Token) + * + * @return 鉴权签名(Access Token) + * @throws IOException IO异常 + */ + public static String getAccessToken() throws IOException { + MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded"); + RequestBody body = RequestBody.create(mediaType, "grant_type=client_credentials&client_id=" + API_KEY + + "&client_secret=" + SECRET_KEY); + Request request = new Request.Builder() + .url("https://aip.baidubce.com/oauth/2.0/token") + .method("POST", body) + .addHeader("Content-Type", "application/x-www-form-urlencoded") + .build(); + Response response = HTTP_CLIENT.newCall(request).execute(); + if (response.body() != null) { + return new JSONObject(response.body().string()).getString("access_token"); + } + return null; + } +} diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/DataDeal.java b/src/main/java/com/sztzjy/marketing/util/algorithm/DataDeal.java index af0660a..3d97320 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/DataDeal.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/DataDeal.java @@ -5,7 +5,7 @@ import java.io.*; public class DataDeal { static double min_support = 4; //最小支持度 static double min_confident = 0.7; //最小置信度 - //文件路径,我的是在当前项目下,大家换成自己的文件路径就是了 + //文件路径 static String filePath = "D:\\home\\marketing\\1.txt"; static ArrayList> D = new ArrayList>();//事务数据库D static HashMap, Integer> C = new HashMap, Integer>();//项目集C diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java index f9ea1d4..756a3f0 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java @@ -171,4 +171,6 @@ public class KMeans { } return points; } + + } \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java new file mode 100644 index 0000000..958747d --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java @@ -0,0 +1,243 @@ +package com.sztzjy.marketing.util.algorithm; + +import org.springframework.web.multipart.MultipartFile; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; + +/** + * @author tz + * @date 2024/7/2 17:37 + */ +public class KMeansResult { + + // 记录迭代的次数 + static int count = 1; + // 文件所在路径 +// static String filePath = "D:\\home\\marketing\\Segmentation_dataset(1).csv"; + // 储存从文件中读取的数据 + static ArrayList> table = new ArrayList>(); + // 储存分类一的结果 + static ArrayList> alist = new ArrayList>(); + // 储存分类二的结果 + static ArrayList> blist = new ArrayList>(); + // 储存分类三的结果 + static ArrayList> clist = new ArrayList>(); + // 记录初始随机产生的3个聚类中心 + static ArrayList> randomList = new ArrayList>(); + + // 读取文件中的数据,储存到集合中 + public static ArrayList> readTable(MultipartFile file){ + table.clear(); + ArrayList d = null; + + try (BufferedReader br = new BufferedReader(new InputStreamReader(file.getInputStream()))) { + + + // 跳过标题行 + br.readLine(); + + String str = null; + while((str = br.readLine()) != null) { + d = new ArrayList(); + String[] str1 = str.split(","); + for(int i = 0; i < str1.length ; i++) { + d.add(Float.parseFloat(str1[i])); + } + table.add(d); + } +// System.out.println(table); + + } catch (Exception e) { + e.printStackTrace(); + System.out.println("文件不存在!"); + } + return table; + } + + // 随机产生k个初始聚类中心 + public static ArrayList> randomList(Integer k,ArrayList> randomList) { + ArrayList> centerList=new ArrayList<>(); + + int[] list = new int[k]; + + // Generate k distinct random numbers + do { + for (int i = 0; i < k; i++) { + list[i] = (int)(Math.random() * 30); + } + } while (hasDuplicates(list)); + + + // For testing purposes, we can take the first k elements from the data set as initial cluster centers + for (int i = 0; i < k; i++) { + centerList.add(table.get(list[i])); + } + + return centerList; + } + + // helper方法来检查数组中是否包含重复项 + private static boolean hasDuplicates(int[] arr) { + Set set = new HashSet<>(); + for (int val : arr) { + if (!set.add(val)) { + return true; + } + } + return false; + } + + //比较两个数的大小,并返回其中较小的数 + public static double minNumber(double x, double y) { + if(x < y) { + return x; + } + return y; + } + + // 计算各个数据到中心点的距离,然后分成k类 + public static void eudistance(ArrayList> list, int k) { + ArrayList>> clusters = new ArrayList<>(k); + for (int i = 0; i < k; i++) { + clusters.add(new ArrayList<>()); + } + + for (int i = 0; i < table.size(); i++) { + double[] distances = new double[k]; + for (int j = 0; j < k; j++) { + distances[j] = calculateEuclideanDistance(table.get(i), list.get(j)); + } + int closestCluster = findClosestCluster(distances); + clusters.get(closestCluster).add(table.get(i)); + } + + System.out.println("第" + count + "次迭代:"); + for (int i = 0; i < k; i++) { + System.out.println("质心 " + i + ": " + clusters.get(i)); + } + System.out.println(); + count++; + } + + //计算两个数据点之间的欧几里得距离 + private static double calculateEuclideanDistance(ArrayList point1, ArrayList point2) { + double sum = 0; + for (int i = 1; i < point1.size(); i++) { + sum += Math.pow(point1.get(i) - point2.get(i), 2); + } + return Math.sqrt(sum); + } + + //根据距离数组找到最近的簇的索引 + private static int findClosestCluster(double[] distances) { + double minDistance = Double.MAX_VALUE; + int closestCluster = 0; + for (int i = 0; i < distances.length; i++) { + if (distances[i] < minDistance) { + minDistance = distances[i]; + closestCluster = i; + } + } + return closestCluster; + } + + //最终聚类结果 + public static ArrayList>> kmeans(int k,Integer t,ArrayList> centroids) { + ArrayList>> clusters = new ArrayList<>(k); + for (int i = 0; i < k; i++) { + clusters.add(new ArrayList<>()); + } + + boolean converged = false; + int iterations = 0; + while (!converged && iterations < t) { + for (int i = 0; i < table.size(); i++) { + ArrayList point = table.get(i); + int closestCluster = findClosestCluster(point, centroids); + clusters.get(closestCluster).add(point); + } + + ArrayList> newCentroids = new ArrayList<>(k); + for (int i = 0; i < k; i++) { + newCentroids.add(calculateCentroid(clusters.get(i))); + } + + converged = checkConvergence(centroids, newCentroids); + + if(!converged){ + centroids = newCentroids; + for (int i = 0; i < k; i++) { + clusters.get(i).clear(); + } + iterations++; + continue; + }else { + break; + } + } + + System.out.println("最终聚类结果:"); + for (int i = 0; i < k; i++) { + System.out.println("簇 " + i + ": " + clusters.get(i)); + } + return clusters; + } + + //计算一个簇的质心 + private static ArrayList calculateCentroid(ArrayList> cluster) { + float sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0; + for (ArrayList point : cluster) { + sum1 += point.get(1); + sum2 += point.get(2); + sum3 += point.get(3); + sum4 += point.get(4); + } + float size = cluster.size(); + ArrayList centroid = new ArrayList<>(); + centroid.add(0f); + centroid.add(sum1 / size); + centroid.add(sum2 / size); + centroid.add(sum3 / size); + centroid.add(sum4 / size); + return centroid; + } + + //检查是否收敛,如果没有收敛,则用新质心替换旧质心,并清空所有簇,准备进行下一轮迭代 + private static boolean checkConvergence(ArrayList> oldCentroids, ArrayList> newCentroids) { + for (int i = 0; i < oldCentroids.size(); i++) { + if (!oldCentroids.get(i).equals(newCentroids.get(i))) { + return false; + } + } + return true; + } + + private static int findClosestCluster(ArrayList point, ArrayList> centroids) { + double minDistance = Double.MAX_VALUE; + int closestCluster = 0; + for (int i = 0; i < centroids.size(); i++) { + double distance = calculateEuclideanDistance(point, centroids.get(i)); + if (distance < minDistance) { + minDistance = distance; + closestCluster = i; + } + } + return closestCluster; + } + + + public static void main(String[] args) { +// ArrayList> rlist = new ArrayList>(); +// readTable(filePath); +// rlist = randomList(); +// eudistance(rlist); +// kmeans(); + } + +} diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java b/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java index 01f683d..4948516 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java @@ -1,31 +1,61 @@ package com.sztzjy.marketing.util.algorithm; +import java.text.DecimalFormat; + public class LinearRegression { private double intercept; - private double slope; + private double[] slopes; // 线性回归的训练方法 - public void fit(double[] x, double[] y) { - int n = x.length; - double sumX = 0.0, sumY = 0.0, sumXY = 0.0, sumX2 = 0.0; + public void fit(double[][] x, double[] y) { + int n = y.length; + int m = x[0].length; + double[] sumX = new double[m]; + double sumY = 0.0; + double[] sumXY = new double[m]; + double[] sumX2 = new double[m]; for (int i = 0; i < n; i++) { - sumX += x[i]; + for (int j = 0; j < m; j++) { + sumX[j] += x[i][j]; + sumXY[j] += x[i][j] * y[i]; + sumX2[j] += x[i][j] * x[i][j]; + } sumY += y[i]; - sumXY += x[i] * y[i]; - sumX2 += x[i] * x[i]; } - double xMean = sumX / n; + double[] xMean = new double[m]; + for (int j = 0; j < m; j++) { + xMean[j] = sumX[j] / n; + } double yMean = sumY / n; + + DecimalFormat df = new DecimalFormat("#.00"); + + + this.slopes = new double[m]; + for (int j = 0; j < m; j++) { + this.slopes[j] = Double.parseDouble(df.format((sumXY[j] - n * xMean[j] * yMean) / (sumX2[j] - n * xMean[j] * xMean[j]))); + } + + this.intercept = Double.parseDouble(df.format(yMean - dotProduct(this.slopes, xMean))); + } - this.slope = (sumXY - n * xMean * yMean) / (sumX2 - n * xMean * xMean); - this.intercept = yMean - this.slope * xMean; + // 辅助方法:计算向量的点积 + private double dotProduct(double[] a, double[] b) { + double result = 0.0; + for (int i = 0; i < a.length; i++) { + result += a[i] * b[i]; + } + return result; } // 预测方法 - public double predict(double x) { - return this.intercept + this.slope * x; + public double predict(double[] x) { + if (x.length != this.slopes.length) { + throw new IllegalArgumentException("输入的x维度与模型不匹配"); + } + return this.intercept + dotProduct(this.slopes, x); } // 获取截距 @@ -34,14 +64,14 @@ public class LinearRegression { } // 获取斜率 - public double getSlope() { - return this.slope; + public double[] getSlopes() { + return this.slopes; } public static void main(String[] args) { // 示例数据 - double[] x = {1, 2, 3, 4, 5}; - double[] y = {2, 3, 5, 6, 8}; + double[][] x = {{1, 19}, {1, 21}, {2, 20}, {2, 23}, {2, 31}}; + double[] y = {15, 15, 16, 16, 17}; // 创建线性回归模型 LinearRegression lr = new LinearRegression(); @@ -49,11 +79,11 @@ public class LinearRegression { // 输出模型参数 System.out.println("截距: " + lr.getIntercept()); - System.out.println("斜率: " + lr.getSlope()); + System.out.println("斜率: " + java.util.Arrays.toString(lr.getSlopes())); // 预测新数据 - double newX = 6; + double[] newX = {2, 45}; double predictedY = lr.predict(newX); - System.out.println("输入x: " + newX + "预测y=" + predictedY); + System.out.println("输入x: " + java.util.Arrays.toString(newX) + "预测y=" + predictedY); } } \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java b/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java index 428a2ec..363d34b 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java @@ -1,196 +1,71 @@ package com.sztzjy.marketing.util.algorithm; +import java.util.Arrays; -import org.ujmp.core.DenseMatrix; -import org.ujmp.core.Matrix; - -/** -* 参考:https://blog.csdn.net/weixin_45040801/article/details/102542209 - * https://blog.csdn.net/ccblogger/article/details/81739200 -* */ public class LogisticRegression { + private double[] weights; + private double bias; - /** - * 训练逻辑回归模型 - * - * @param data 输入的特征数据 - * @param classValues 目标类别值 - * @return 训练得到的模型参数 - */ - public static double[] train(double[][] data, double[] classValues) { - - if (data != null && classValues != null && data.length == classValues.length) { -// 期望矩阵 -// Matrix matrWeights = DenseMatrix.Factory.zeros(data[0].length + 1, 1); - Matrix matrWeights = DenseMatrix.Factory.zeros(data[0].length, 1); - System.out.println("data[0].length + 1========"+data[0].length + 1); -// 数据矩阵 -// Matrix matrData = DenseMatrix.Factory.zeros(data.length, data[0].length + 1); - Matrix matrData = DenseMatrix.Factory.zeros(data.length, data[0].length); -// 标志矩阵 - Matrix matrLable = DenseMatrix.Factory.zeros(data.length, 1); -// 训练速率矩阵 -// Matrix matrRate = DenseMatrix.Factory.zeros(data[0].length + 1,data.length); -// 统计difference的总体失误的辅助矩阵 - Matrix matrDiffUtil = DenseMatrix.Factory.zeros(data[0].length,data.length); - for(int i=0;i