diff --git a/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java b/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java index a14a448..4b1555a 100644 --- a/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java +++ b/src/main/java/com/sztzjy/marketing/controller/stu/StuDigitalMarketingModelController.java @@ -30,7 +30,7 @@ import java.math.BigDecimal; import java.text.DecimalFormat; import java.util.*; -import static com.sztzjy.marketing.util.algorithm.KMeans.readIrisData; +import static com.sztzjy.marketing.util.algorithm.KMeans.*; import static com.sztzjy.marketing.util.algorithm.KMeansResult.*; /** @@ -62,8 +62,9 @@ public class StuDigitalMarketingModelController { @PostMapping("/viewMetrics") @AnonymousAccess public ResultEntity viewMetrics(@ApiParam("用户ID") String userId, - @ApiParam("表格名") String tableName) { - List list=modelService.viewMetrics(userId,tableName); + @ApiParam("表格名") String tableName, + @ApiParam("算法名称") String algorithmName) { + List list=modelService.viewMetrics(userId,tableName,algorithmName); return new ResultEntity(HttpStatus.OK,"成功",list); } @@ -112,7 +113,9 @@ public class StuDigitalMarketingModelController { Integer k = clusterScatterPlotDTO.getK(); List> deduplicatedDataList = clusterScatterPlotDTO.getDeduplicatedDataList(); - List irisData = readIrisData(deduplicatedDataList); + List strings = convertToDelimitedStringList(deduplicatedDataList); + + List irisData = readIrisData(strings); //获取数据集 KMeans kMeans = new KMeans(k, t, irisData); @@ -148,12 +151,16 @@ public class StuDigitalMarketingModelController { @ApiOperation("聚类分析--分析结果") @PostMapping("/clusterAnalysisResult") @AnonymousAccess - public ResultEntity clusterAnalysisResult(@ApiParam("簇数") Integer k, - @ApiParam("最大迭代次数") Integer t, - @ApiParam("最大迭代次数") String userId, - @RequestParam(required = false) @RequestPart MultipartFile file) { + public ResultEntity clusterAnalysisResult(@RequestBody ClusterScatterPlotDTO clusterScatterPlotDTO) { + + Integer t = clusterScatterPlotDTO.getT(); + Integer k = clusterScatterPlotDTO.getK(); + List> deduplicatedDataList = clusterScatterPlotDTO.getDeduplicatedDataList(); + + List strings = convertToDelimitedStringList(deduplicatedDataList); + //初始化数据 - ArrayList> arrayLists = readTable(file); + ArrayList> arrayLists = readTable(strings); //随机选择 k 个数据点作为初始聚类中心 @@ -163,7 +170,7 @@ public class StuDigitalMarketingModelController { ArrayList>> kmeans = kmeans(k, t, centerList); - return new ResultEntity(HttpStatus.OK); + return new ResultEntity(HttpStatus.OK,kmeans); } @@ -171,20 +178,23 @@ public class StuDigitalMarketingModelController { @ApiOperation("关联规则挖掘") @PostMapping("/apriori") @AnonymousAccess - public ResultEntity apriori(@ApiParam("最小支持度阀值") double support, - @ApiParam("最小置信度") double confidence, - @ApiParam("用户ID") String userId, - @RequestParam(required = false) @RequestPart MultipartFile file) throws IOException { + public ResultEntity apriori(@RequestBody AprioriDTO aprioriDTO) throws IOException { + + double support = aprioriDTO.getSupport(); + double confidence = aprioriDTO.getConfidence(); + String userId = aprioriDTO.getUserId(); + List> deduplicatedDataList = aprioriDTO.getDeduplicatedDataList(); + List strings = convertToDelimitedStringList(deduplicatedDataList); //初始化事务数据库、项目集、候选集再进行剪枝 - Apriori.init(file,support); + Apriori.init(strings,support); //迭代求出最终的候选频繁集 Apriori.iteration(Apriori.C,Apriori.L,support); //根据最终的关联集,根据公式计算出各个关联事件 - List connection = Apriori.connection(confidence); + List connection = Apriori.connection(confidence); return new ResultEntity(HttpStatus.OK,"成功",connection); @@ -202,26 +212,38 @@ public class StuDigitalMarketingModelController { LinearRegression regression=new LinearRegression(); regression.fit(x,y); + double[] slopes = regression.getSlopes(); + + RAnalysisDTO rAnalysisDTO=new RAnalysisDTO(); + rAnalysisDTO.setIntercept(regression.getIntercept()); + for (int i = 0; i < slopes.length; i++) { + rAnalysisDTO.setSlopes(slopes[i]); + } - return new ResultEntity(HttpStatus.OK,"成功",regression); + return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO); } -// @ApiOperation("回归分析--逻辑回归") -// @PostMapping("/logistic") -// @AnonymousAccess -// public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) { -// double[] x = logisticDTO.getX(); -// double[] y = logisticDTO.getY(); -// -// // 创建线性回归模型 -// LogisticRegression regression=new LogisticRegression(); -// regression.fit(x,y); -// -// -// return new ResultEntity(HttpStatus.OK,"成功",regression); -// } + @ApiOperation("回归分析--逻辑回归") + @PostMapping("/logistic") + @AnonymousAccess + public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) { + double[][] x = logisticDTO.getX(); + double[] y = logisticDTO.getY(); + + // 创建一个逻辑回归模型实例,0.1学习率 + LogisticRegression model = new LogisticRegression(0.1); + model.train(x,y,1000); //训练模型 + + DecimalFormat df = new DecimalFormat("#.0"); + RAnalysisDTO rAnalysisDTO=new RAnalysisDTO(); + + rAnalysisDTO.setIntercept(Double.parseDouble(df.format(model.getIntercept()))); + rAnalysisDTO.setSlopes(Double.parseDouble(df.format(model.getCoefficient()))); + + return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO); + } @ApiOperation("情感分析/文本挖掘") @@ -237,7 +259,7 @@ public class StuDigitalMarketingModelController { @ApiOperation("批量导入") @PostMapping("/batchImport") @AnonymousAccess - public ResultEntity batchImport(@RequestPart MultipartFile file) { + public ResultEntity batchImport(@RequestParam(required = false) @RequestPart MultipartFile file) { //验证文件类型 if (!file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf(".")).equals(".xls") && !file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf(".")).equals(".xlsx")) { @@ -262,10 +284,10 @@ public class StuDigitalMarketingModelController { StuSpendingLevel stuSpendingLevel=new StuSpendingLevel(); stuSpendingLevel.setId(Integer.valueOf((String) list.get(0))); - stuSpendingLevel.setGender(Integer.valueOf((String) list.get(0))); - stuSpendingLevel.setAge(Integer.valueOf((String) list.get(0))); - stuSpendingLevel.setAnnualIncome(Integer.valueOf((String) list.get(0))); - stuSpendingLevel.setSpendingScore(Integer.valueOf((String) list.get(0))); + stuSpendingLevel.setGender(Integer.valueOf((String) list.get(1))); + stuSpendingLevel.setAge(Integer.valueOf((String) list.get(2))); + stuSpendingLevel.setAnnualIncome(Integer.valueOf((String) list.get(3))); + stuSpendingLevel.setSpendingScore(Integer.valueOf((String) list.get(4))); diff --git a/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActive.java b/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActive.java index b29dbf0..b804953 100644 --- a/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActive.java +++ b/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActive.java @@ -2,7 +2,6 @@ package com.sztzjy.marketing.entity; import java.util.Date; -import io.swagger.annotations.ApiModel; import io.swagger.annotations.ApiModelProperty; /** * 用户登录活跃表 @@ -10,7 +9,6 @@ import io.swagger.annotations.ApiModelProperty; * @author whb * stu_user_login_active */ -@ApiModel("用户登录活跃表") public class StuUserLoginActive { @ApiModelProperty("id") private Integer id; @@ -48,8 +46,8 @@ public class StuUserLoginActive { @ApiModelProperty("登录频次") private Integer loginFrequency; - @ApiModelProperty("登录时长") - private String loginDuration; + @ApiModelProperty("登录时长(分钟)") + private Double loginDuration; @ApiModelProperty("创建时间") private Date createTime; @@ -153,12 +151,12 @@ public class StuUserLoginActive { this.loginFrequency = loginFrequency; } - public String getLoginDuration() { + public Double getLoginDuration() { return loginDuration; } - public void setLoginDuration(String loginDuration) { - this.loginDuration = loginDuration == null ? null : loginDuration.trim(); + public void setLoginDuration(Double loginDuration) { + this.loginDuration = loginDuration; } public Date getCreateTime() { diff --git a/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActiveExample.java b/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActiveExample.java index 2fa6933..18d09a4 100644 --- a/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActiveExample.java +++ b/src/main/java/com/sztzjy/marketing/entity/StuUserLoginActiveExample.java @@ -925,62 +925,52 @@ public class StuUserLoginActiveExample { return (Criteria) this; } - public Criteria andLoginDurationEqualTo(String value) { + public Criteria andLoginDurationEqualTo(Double value) { addCriterion("login_duration =", value, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationNotEqualTo(String value) { + public Criteria andLoginDurationNotEqualTo(Double value) { addCriterion("login_duration <>", value, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationGreaterThan(String value) { + public Criteria andLoginDurationGreaterThan(Double value) { addCriterion("login_duration >", value, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationGreaterThanOrEqualTo(String value) { + public Criteria andLoginDurationGreaterThanOrEqualTo(Double value) { addCriterion("login_duration >=", value, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationLessThan(String value) { + public Criteria andLoginDurationLessThan(Double value) { addCriterion("login_duration <", value, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationLessThanOrEqualTo(String value) { + public Criteria andLoginDurationLessThanOrEqualTo(Double value) { addCriterion("login_duration <=", value, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationLike(String value) { - addCriterion("login_duration like", value, "loginDuration"); - return (Criteria) this; - } - - public Criteria andLoginDurationNotLike(String value) { - addCriterion("login_duration not like", value, "loginDuration"); - return (Criteria) this; - } - - public Criteria andLoginDurationIn(List values) { + public Criteria andLoginDurationIn(List values) { addCriterion("login_duration in", values, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationNotIn(List values) { + public Criteria andLoginDurationNotIn(List values) { addCriterion("login_duration not in", values, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationBetween(String value1, String value2) { + public Criteria andLoginDurationBetween(Double value1, Double value2) { addCriterion("login_duration between", value1, value2, "loginDuration"); return (Criteria) this; } - public Criteria andLoginDurationNotBetween(String value1, String value2) { + public Criteria andLoginDurationNotBetween(Double value1, Double value2) { addCriterion("login_duration not between", value1, value2, "loginDuration"); return (Criteria) this; } diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/AprioriDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/AprioriDTO.java new file mode 100644 index 0000000..e1c0d1a --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/entity/dto/AprioriDTO.java @@ -0,0 +1,23 @@ +package com.sztzjy.marketing.entity.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.List; +import java.util.Map; + +/** + * @author tz + * @date 2024/8/17 14:16 + */ +@Data +public class AprioriDTO { + @ApiModelProperty("最小支持度阀值") + private double support; + @ApiModelProperty("最小置信度") + private double confidence; + @ApiModelProperty("用户ID") + private String userId; + @ApiModelProperty("数据集") + private List> deduplicatedDataList; +} diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/DescriptiveStatistics.java b/src/main/java/com/sztzjy/marketing/entity/dto/DescriptiveStatistics.java index 56342a3..d0379ae 100644 --- a/src/main/java/com/sztzjy/marketing/entity/dto/DescriptiveStatistics.java +++ b/src/main/java/com/sztzjy/marketing/entity/dto/DescriptiveStatistics.java @@ -18,7 +18,7 @@ public class DescriptiveStatistics { private double median; @ApiModelProperty("众数") - private List mode; + private double mode; @ApiModelProperty("标准差") private double standardDeviation; diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java index b32dad1..b5cc2f5 100644 --- a/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java +++ b/src/main/java/com/sztzjy/marketing/entity/dto/LogisticDTO.java @@ -11,7 +11,7 @@ import lombok.Data; public class LogisticDTO { @ApiModelProperty("自变量x") - private double[] x; + private double[][] x; @ApiModelProperty("因变量y") private double[] y; diff --git a/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java b/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java new file mode 100644 index 0000000..d2cae68 --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/entity/dto/RAnalysisDTO.java @@ -0,0 +1,13 @@ +package com.sztzjy.marketing.entity.dto; + +import lombok.Data; + +/** + * @author tz + * @date 2024/8/17 16:30 + */ +@Data +public class RAnalysisDTO { + private double intercept; + private double slopes; +} diff --git a/src/main/java/com/sztzjy/marketing/mapper/StuUserBehaviorMapper.java b/src/main/java/com/sztzjy/marketing/mapper/StuUserBehaviorMapper.java index 1568b0a..9b908ac 100644 --- a/src/main/java/com/sztzjy/marketing/mapper/StuUserBehaviorMapper.java +++ b/src/main/java/com/sztzjy/marketing/mapper/StuUserBehaviorMapper.java @@ -3,10 +3,8 @@ package com.sztzjy.marketing.mapper; import com.sztzjy.marketing.entity.StuUserBehavior; import com.sztzjy.marketing.entity.StuUserBehaviorExample; import java.util.List; - -import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Param; -@Mapper + public interface StuUserBehaviorMapper { long countByExample(StuUserBehaviorExample example); @@ -29,6 +27,4 @@ public interface StuUserBehaviorMapper { int updateByPrimaryKeySelective(StuUserBehavior record); int updateByPrimaryKey(StuUserBehavior record); - - List selectByFields(List fieldList, String table); } \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/mapper/StuUserLoginActiveMapper.java b/src/main/java/com/sztzjy/marketing/mapper/StuUserLoginActiveMapper.java index 61fcd79..76b0ecc 100644 --- a/src/main/java/com/sztzjy/marketing/mapper/StuUserLoginActiveMapper.java +++ b/src/main/java/com/sztzjy/marketing/mapper/StuUserLoginActiveMapper.java @@ -27,6 +27,4 @@ public interface StuUserLoginActiveMapper { int updateByPrimaryKeySelective(StuUserLoginActive record); int updateByPrimaryKey(StuUserLoginActive record); - - List selectByFields(List fieldList, String table); } \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java b/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java index 8a5357b..6b286c3 100644 --- a/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java +++ b/src/main/java/com/sztzjy/marketing/service/StuDigitalMarketingModelService.java @@ -20,7 +20,7 @@ public interface StuDigitalMarketingModelService { ResultEntity descriptiveStatistics(Map> map, List statistic, String userId); - List viewMetrics(String userId, String tableName); + List viewMetrics(String userId, String tableName,String algorithmName); ResultEntity viewAnalyzeData(AnalyzeDataDTO analyzeDataDTO); diff --git a/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java b/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java index 28ed716..da51f31 100644 --- a/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java +++ b/src/main/java/com/sztzjy/marketing/service/impl/StuDigitalMarketingModelServiceImpl.java @@ -19,6 +19,7 @@ import org.springframework.stereotype.Service; import javax.annotation.Resource; import java.io.IOException; +import java.text.DecimalFormat; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneId; @@ -190,13 +191,14 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } if(statistic.get(i).equals("众数")){ - List mode = statisticsUtil.getMode(map.get(key)); + Double mode = statisticsUtil.getMode(map.get(key)); statistics.setMode(mode); } if(statistic.get(i).equals("标准差")){ double standardDeviation = statisticsUtil.getStandardDeviation(map.get(key)); - statistics.setStandardDeviation(standardDeviation); + DecimalFormat df = new DecimalFormat("#.0"); + statistics.setStandardDeviation(Double.parseDouble(df.format(standardDeviation))); } if(statistic.get(i).equals("方差")){ @@ -211,12 +213,24 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } if(statistic.get(i).equals("峰度")){ double kurtosis = statisticsUtil.getKurtosis(map.get(key)); - statistics.setKurtosis(kurtosis); + + if(Double.isNaN(kurtosis)){ + statistics.setKurtosis(0); + }else { + statistics.setKurtosis(kurtosis); + + } } if(statistic.get(i).equals("偏度")){ double skewness = statisticsUtil.getSkewness(map.get(key)); - statistics.setSkewness(skewness); + + if(Double.isNaN(skewness)){ + statistics.setSkewness(0); + }else { + statistics.setSkewness(skewness); + } + } if(statistic.get(i).equals("最大值")){ @@ -248,7 +262,7 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } @Override - public List viewMetrics(String userId, String tableName) { + public List viewMetrics(String userId, String tableName,String algorithmName) { List list=new ArrayList<>(); if(tableName.equals(Constant.YHSXB)){ @@ -258,14 +272,33 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM list.add("role_age"); } if(tableName.equals(Constant.YHDLHYB)){ - list=indicatorsMapper.getYHDLHYB(); +// list=indicatorsMapper.getYHDLHYB(); + list.add("id"); + list.add("login_frequency"); + list.add("login_duration"); } if(tableName.equals(Constant.YHXFNLB)){ - list=indicatorsMapper.getYHXFNLB(); +// list=indicatorsMapper.getYHXFNLB(); + if(algorithmName.equals("关联规则挖掘")){ + list.add("consumer_goods"); + } else { + list.add("id"); + list.add("consumer_amount"); + list.add("consumer_number"); + list.add("consumer_past_seven_days_amount"); + list.add("consumer_past_seven_days_number"); + } + } if(tableName.equals(Constant.YHPLB) || tableName.equals(Constant.YHXWB)){ - list=indicatorsMapper.getYHPLB(); +// list=indicatorsMapper.getYHPLB(); + if(algorithmName.equals("关联规则挖掘")){ + list.add("goods_name"); + }else { + list.add("id"); + list.add("user_behavior_type"); + } } return list; } @@ -285,7 +318,7 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM } if(analyzeDataDTO.getTableName().equals(Constant.YHXFNLB)){ //查询用户消费能力表 - table="stu_spending_level"; + table="stu_user_consumption_ability"; } if(analyzeDataDTO.getTableName().equals(Constant.YHPLB) || analyzeDataDTO.getTableName().equals(Constant.YHXWB)){ //查询用户评论或行为表 diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java b/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java index 96db4dd..dc93ee8 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/Apriori.java @@ -30,25 +30,39 @@ public class Apriori { // 用于存取候选集每次计算结果(即存放所有的频繁项集L),最后计算关联规则,就不用再次遍历事务数据库。 static HashMap, Integer> L_ALL = new HashMap, Integer>(); //从文件中读取内容,返回事务集 - public static ArrayList> readTable(MultipartFile file) { - ArrayList> t = new ArrayList<>(); - try (InputStream inputStream = file.getInputStream(); - Workbook workbook = new XSSFWorkbook(inputStream)) { - Sheet sheet = workbook.getSheetAt(0); - for (Row row : sheet) { - ArrayList rowData = new ArrayList<>(); - for (Cell cell : row) { - rowData.add(cell.getStringCellValue()); - } - t.add(rowData); + + public static ArrayList> readTable(List inputList) { + ArrayList> result = new ArrayList<>(); + for (String line : inputList) { + String[] values = line.split(","); + ArrayList rowData = new ArrayList<>(); + for (String value : values) { + rowData.add(value); } - } catch (IOException e) { - e.printStackTrace(); - System.out.println("文件读取失败!"); + result.add(rowData); } - System.out.println("\nD:" + t); - return t; + System.out.println("\nD:" + result); + return result; } +// public static ArrayList> readTable(MultipartFile file) { +// ArrayList> t = new ArrayList<>(); +// try (InputStream inputStream = file.getInputStream(); +// Workbook workbook = new XSSFWorkbook(inputStream)) { +// Sheet sheet = workbook.getSheetAt(0); +// for (Row row : sheet) { +// ArrayList rowData = new ArrayList<>(); +// for (Cell cell : row) { +// rowData.add(cell.getStringCellValue()); +// } +// t.add(rowData); +// } +// } catch (IOException e) { +// e.printStackTrace(); +// System.out.println("文件读取失败!"); +// } +// System.out.println("\nD:" + t); +// return t; +// } //剪枝步:从候选集C中删除小于最小支持度的,并放入频繁集L中 public static void pruning(HashMap, Integer> C,HashMap, Integer> L,double min_support) { @@ -70,9 +84,9 @@ public class Apriori { /** * 初始化事务数据库、项目集、候选集 */ - public static void init(MultipartFile file,double min_support) throws IOException { + public static void init( List strings,double min_support) throws IOException { // //将文件中的数据放入集合D中 - D = readTable(file); + D = readTable(strings); // 扫描事务数据库。生成项目集,支持度=该元素在事务数据库出现的次数/事务数据库的事务数 for (int i = 0; i < D.size(); i++) { for (int j = 0; j < D.get(i).size(); j++) { @@ -198,10 +212,12 @@ public class Apriori { /** * 根据最终的关联集,根据公式计算出各个关联事件 */ - public static List connection(double min_confident) { + public static List connection(double min_confident) { List list=new ArrayList<>(); + List lists=new ArrayList<>(); + for (ArrayList key : L_ALL.keySet()) {// 对最终的关联集各个事件进行判断 ArrayList> key_allSubset = getSubset(key); //得到所有频繁集中每个集合的子集 @@ -235,8 +251,14 @@ public class Apriori { list.add(associationRulesDTO); + + System.out.print(item_pre + "==>" + item_post );// 则是一个关联事件 System.out.println("==>" + confident); + + String string=item_pre + "==>" + item_post+"==>" + confident; + + lists.add(string); } } @@ -244,7 +266,7 @@ public class Apriori { } } } - return list; + return lists; } public static void main(String[] args) { diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/Data.java b/src/main/java/com/sztzjy/marketing/util/algorithm/Data.java new file mode 100644 index 0000000..9a2c9b6 --- /dev/null +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/Data.java @@ -0,0 +1,11 @@ +package com.sztzjy.marketing.util.algorithm; + +public class Data { + public double[] features; + public double label; + + public Data(double[] features, double label) { + this.features = features; + this.label = label; + } +} \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/DescriptiveStatisticsUtil.java b/src/main/java/com/sztzjy/marketing/util/algorithm/DescriptiveStatisticsUtil.java index e25e1eb..83f9e7d 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/DescriptiveStatisticsUtil.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/DescriptiveStatisticsUtil.java @@ -32,7 +32,7 @@ public class DescriptiveStatisticsUtil { // 计算众数 - public List getMode(List data) { + public Double getMode(List data) { HashMap map = new HashMap(); double imode = 0; for (int i = 0; i < data.size(); i++) { @@ -57,7 +57,10 @@ public class DescriptiveStatisticsUtil { lst.add(key); } } - return lst; + + Object object = lst.get(0); + Double v = Double.valueOf(object.toString()); + return v; } @@ -90,6 +93,7 @@ public class DescriptiveStatisticsUtil { public double getKurtosis(List data) { DescriptiveStatistics stats = new DescriptiveStatistics(); for (double value : data) { + stats.addValue(value); } return stats.getKurtosis(); @@ -101,7 +105,8 @@ public class DescriptiveStatisticsUtil { for (double value : data) { stats.addValue(value); } - return stats.getSkewness(); + double skewness = stats.getSkewness(); + return skewness; } // 获取最大值 diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java index 804853e..34e54ea 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeans.java @@ -6,10 +6,9 @@ import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import java.io.InputStreamReader; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class KMeans { @@ -134,7 +133,7 @@ public class KMeans { } } -// public static void main(String[] args) { + public static void main(String[] args) { // List irisData = readIrisData("D:\\home\\marketing\\Iris.txt"); // int k = 3; // 假设聚为3类 // int maxIterations = 100; @@ -151,22 +150,35 @@ public class KMeans { // for (int i = 0; i < centroids.size(); i++) { // System.out.println("Centroid for cluster " + i + ": " + centroids.get(i)); // } -// } - public static List readIrisData(List> deduplicatedDataList) { - List points = new ArrayList<>(); - for (Map data : deduplicatedDataList) { - // 假设每个Map都有"x"和"y"键,并且它们的值是Double类型 - // 这里没有错误处理,你可能需要添加一些来确保键存在且值可以转换为Double - Double x = (Double) data.get("x"); - Double y = (Double) data.get("y"); - - if (x != null && y != null) { - points.add(new KMeans.Point(x, y)); - } else { - // 处理缺少x或y的情况,比如记录日志或抛出异常 - System.err.println("数据中缺少x或y: " + data); + } + + public static List convertToDelimitedStringList(List> nestedList) { + List flatList = new ArrayList<>(); + for (Map innerMap : nestedList) { + StringBuilder sb = new StringBuilder(); + for (Object value : innerMap.values()) { + if (sb.length() > 0) { + sb.append(","); + } + sb.append(value); } + flatList.add(sb.toString()); + } + return flatList; + } + + + public static List readIrisData(List data) { + List points = new ArrayList<>(); + for (String line : data) { + String[] values = line.split(","); + + double x = Double.parseDouble(values[0]); + double y = Double.parseDouble(values[1]); + + + points.add(new KMeans.Point(x, y)); } return points; } diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java index 958747d..fa0f6c9 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/KMeansResult.java @@ -8,6 +8,7 @@ import java.io.FileInputStream; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.HashSet; +import java.util.List; import java.util.Set; /** @@ -32,33 +33,48 @@ public class KMeansResult { static ArrayList> randomList = new ArrayList>(); // 读取文件中的数据,储存到集合中 - public static ArrayList> readTable(MultipartFile file){ + public static ArrayList> readTable(List data) { table.clear(); - ArrayList d = null; - try (BufferedReader br = new BufferedReader(new InputStreamReader(file.getInputStream()))) { - - // 跳过标题行 - br.readLine(); - - String str = null; - while((str = br.readLine()) != null) { - d = new ArrayList(); - String[] str1 = str.split(","); - for(int i = 0; i < str1.length ; i++) { - d.add(Float.parseFloat(str1[i])); - } - table.add(d); + for (String line : data) { + ArrayList d = new ArrayList(); + String[] str1 = line.split(","); + for (String str2 : str1) { + d.add(Float.parseFloat(str2)); } -// System.out.println(table); - - } catch (Exception e) { - e.printStackTrace(); - System.out.println("文件不存在!"); + table.add(d); } + return table; } +// public static ArrayList> readTable(MultipartFile file){ +// table.clear(); +// ArrayList d = null; +// +// try (BufferedReader br = new BufferedReader(new InputStreamReader(file.getInputStream()))) { +// +// +// // 跳过标题行 +// br.readLine(); +// +// String str = null; +// while((str = br.readLine()) != null) { +// d = new ArrayList(); +// String[] str1 = str.split(","); +// for(int i = 0; i < str1.length ; i++) { +// d.add(Float.parseFloat(str1[i])); +// } +// table.add(d); +// } +//// System.out.println(table); +// +// } catch (Exception e) { +// e.printStackTrace(); +// System.out.println("文件不存在!"); +// } +// return table; +// } // 随机产生k个初始聚类中心 public static ArrayList> randomList(Integer k,ArrayList> randomList) { @@ -66,15 +82,16 @@ public class KMeansResult { int[] list = new int[k]; - // Generate k distinct random numbers + // 生成k个不同的随机数 + int size = table.size(); //生成的随机数不能大于table中元素的个数 do { for (int i = 0; i < k; i++) { - list[i] = (int)(Math.random() * 30); + list[i] = (int)(Math.random() * size); } } while (hasDuplicates(list)); - // For testing purposes, we can take the first k elements from the data set as initial cluster centers + // 将数据集中的前k个元素作为初始集群中心 for (int i = 0; i < k; i++) { centerList.add(table.get(list[i])); } @@ -193,11 +210,24 @@ public class KMeansResult { private static ArrayList calculateCentroid(ArrayList> cluster) { float sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0; for (ArrayList point : cluster) { - sum1 += point.get(1); - sum2 += point.get(2); - sum3 += point.get(3); - sum4 += point.get(4); + for (int i = 0; i < point.size(); i++) { + if(i==0){ + sum1 += point.get(0); + } + if(i==1){ + sum2 += point.get(1); + } + if(i==2){ + sum3 += point.get(2); + } + if(i==3){ + sum4 += point.get(3); + } + + } + } + float size = cluster.size(); ArrayList centroid = new ArrayList<>(); centroid.add(0f); diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java b/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java index 4948516..a370d24 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/LinearRegression.java @@ -1,6 +1,8 @@ package com.sztzjy.marketing.util.algorithm; import java.text.DecimalFormat; +import java.util.ArrayList; +import java.util.List; public class LinearRegression { private double intercept; @@ -76,14 +78,21 @@ public class LinearRegression { // 创建线性回归模型 LinearRegression lr = new LinearRegression(); lr.fit(x, y); - + + DecimalFormat df = new DecimalFormat("#.0"); + + List list=new ArrayList<>(); + for (int i = 0; i < lr.getSlopes().length; i++) { + String format = df.format(lr.getSlopes()[i]); + list.add(format); + } // 输出模型参数 - System.out.println("截距: " + lr.getIntercept()); - System.out.println("斜率: " + java.util.Arrays.toString(lr.getSlopes())); + System.out.println("常数项值: " + df.format(lr.getIntercept())); + System.out.println("第一个特征系数值: " + list.get(0)); // 预测新数据 - double[] newX = {2, 45}; + double[] newX = {2, 88}; double predictedY = lr.predict(newX); - System.out.println("输入x: " + java.util.Arrays.toString(newX) + "预测y=" + predictedY); + System.out.println("输入x: " + java.util.Arrays.toString(newX) + "预测y=" + df.format(predictedY)); } } \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java b/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java index 363d34b..d0917af 100644 --- a/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java +++ b/src/main/java/com/sztzjy/marketing/util/algorithm/LogisticRegression.java @@ -1,71 +1,121 @@ package com.sztzjy.marketing.util.algorithm; -import java.util.Arrays; +import com.sztzjy.marketing.util.BigDecimalUtils; + +import java.text.DecimalFormat; public class LogisticRegression { - private double[] weights; - private double bias; + private double theta0; // 截距项 + private double theta1; // 斜率 + private double learningRate; // 学习率 + + private int currentEpoch; // 声明currentEpoch变量 - public LogisticRegression(int numFeatures) { - weights = new double[numFeatures]; - bias = 0; + public LogisticRegression(double learningRate) { + this.learningRate = learningRate; } - public void fit(double[][] X, double[] y, double learningRate, int numIterations) { - int numSamples = X.length; - int numFeatures = X[0].length; - for (int iter = 0; iter < numIterations; iter++) { - double[] gradientWeights = new double[numFeatures]; - double gradientBias = 0; - for (int i = 0; i < numSamples; i++) { - double dotProduct = dotProduct(X[i], weights) + bias; - double prediction = sigmoid(dotProduct); - double error = y[i] - prediction; - for (int j = 0; j < numFeatures; j++) { - gradientWeights[j] += error * X[i][j]; - } - gradientBias += error; - } + public LogisticRegression(double learningRate, double initialTheta0, double initialTheta1) { + this.learningRate = learningRate; + this.theta0 = initialTheta0; + this.theta1 = initialTheta1; + this.currentEpoch = 0; // 在构造函数中初始化currentEpoch + } - for (int j = 0; j < numFeatures; j++) { - weights[j] += learningRate * gradientWeights[j] / numSamples; - } - bias += learningRate * gradientBias / numSamples; - } + // 添加一个方法来获取截距项 + public double getIntercept() { + return theta0; } - public double predict(double[] x) { - double dotProduct = dotProduct(x, weights) + bias; - return sigmoid(dotProduct); + // 添加一个方法来获取第一个特征的系数 + public double getCoefficient() { + return theta1; } - private double sigmoid(double x) { - return 1 / (1 + Math.exp(-x)); + // Sigmoid函数 + private double sigmoid(double z) { + return 1 / (1 + Math.exp(-z)); } - private double dotProduct(double[] x, double[] y) { - double sum = 0; - for (int i = 0; i < x.length; i++) { - sum += x[i] * y[i]; + // 预测函数 + public double predict(double x) { + BigDecimalUtils bigDecimalUtils=new BigDecimalUtils(); + + Double mul = bigDecimalUtils.mul(this.theta0, x); + + return bigDecimalUtils.add(this.theta1, mul); + + } + + private double dotProduct(double[] a, double[] b) { + double result = 0.0; + for (int i = 0; i < a.length; i++) { + result += a[i] * b[i]; } - return sum; + return result; + } + + // 梯度下降更新参数 + public void updateParameters(double x, double y) { + double h = predict(x); + double error = y - h; + theta1 += learningRate * error * h * (1 - h) * x; + theta0 += learningRate * error * h * (1 - h); + } + private double exponentialDecayLearningRate(double currentEpoch, int totalEpochs) { + return learningRate * Math.exp(-(currentEpoch / totalEpochs)); + } + + // 训练模型 + public void train(double[][] data, double[] labels, int epochs) { + double learningRateCurrentEpoch = exponentialDecayLearningRate(currentEpoch, epochs); + for (int epoch = 0; epoch < epochs; epoch++) { + for (int i = 0; i < data.length; i++) { + double x = data[i][0]; + double y = labels[i]; + double h = sigmoid(theta0 + theta1 * x); + double error = y - h; + theta1 += learningRateCurrentEpoch * error * h * (1 - h) * x; + theta0 += learningRateCurrentEpoch * error * h * (1 - h); + } + // 可以添加打印语句查看训练情况 +// System.out.println("Epoch: " + (epoch + 1) + ", Loss: " + calculateLoss(data, labels)); + } + } + public double calculateLoss(double[][] data, double[] labels) { + double loss = 0; + for (int i = 0; i < data.length; i++) { + double x = data[i][0]; + double y = labels[i]; + double h = sigmoid(theta0 + theta1 * x); + loss += -y * Math.log(h) - (1 - y) * Math.log(1 - h); + } + return loss / data.length; } public static void main(String[] args) { - double[][] X = {{1, 1}, {1, 2}, {2, 1}, {2, 3}}; - double[] y = {0, 0, 0, 1}; + LogisticRegression model = new LogisticRegression(0.1); // 创建一个逻辑回归模型实例 + + double[][] data = {{1}, {2}, {3}, {4}, {5}}; + double[] labels = {0, 0, 1, 1, 1}; // 假设这是一个二分类问题 + +// double[][] data = {{1, 19}, {1, 21}, {2, 20}, {2, 23}, {2, 31}}; +// double[] labels = {15, 15, 16, 16, 17}; + + model.train(data, labels, 1000); // 训练模型 - LogisticRegression model = new LogisticRegression(2); - model.fit(X, y, 0.01, 10000); + DecimalFormat df = new DecimalFormat("#.0"); - System.out.println("Weights: " + Arrays.toString(model.weights)); - System.out.println("Bias: " + model.bias); + // 获取并打印截距项和系数 + double intercept = model.getIntercept(); + double coefficient = model.getCoefficient(); + System.out.println("常数项值: " + df.format(intercept)); + System.out.println("系数值 " + df.format(coefficient)); - double[] newSample = {1, 2}; - double prediction = model.predict(newSample); - System.out.println("Prediction for " + Arrays.toString(newSample) + ": " + prediction); + double prediction = model.predict(2); + System.out.println("x=3.5的预测: " + df.format(prediction)); } } \ No newline at end of file diff --git a/src/main/java/com/sztzjy/marketing/util/excel/ImportExcelUtil.java b/src/main/java/com/sztzjy/marketing/util/excel/ImportExcelUtil.java index 9270e72..95ec016 100644 --- a/src/main/java/com/sztzjy/marketing/util/excel/ImportExcelUtil.java +++ b/src/main/java/com/sztzjy/marketing/util/excel/ImportExcelUtil.java @@ -6,6 +6,7 @@ import org.apache.poi.ss.usermodel.*; import org.apache.poi.xssf.usermodel.XSSFWorkbook; import org.springframework.stereotype.Component; +import java.io.IOException; import java.io.InputStream; import java.math.BigDecimal; import java.text.DecimalFormat; diff --git a/src/main/resources/mappers/StuSpendingLevelMapper.xml b/src/main/resources/mappers/StuSpendingLevelMapper.xml index 6a3b65b..388134f 100644 --- a/src/main/resources/mappers/StuSpendingLevelMapper.xml +++ b/src/main/resources/mappers/StuSpendingLevelMapper.xml @@ -209,10 +209,10 @@ where id = #{id,jdbcType=INTEGER} - INSERT INTO tch_start_course_name_list (id, gender, age, annual_income, spending_score) + INSERT INTO stu_spending_level (id, gender, age, annual_income, spending_score) VALUES - (#{stuSpendingLevel.id}, #{stuSpendingLevel.gender}, #{stuSpendingLevel.age}, #{stuSpendingLevel.annual_income}, #{stuSpendingLevel.spending_score}) + (#{stuSpendingLevel.id}, #{stuSpendingLevel.gender}, #{stuSpendingLevel.age}, #{stuSpendingLevel.annualIncome}, #{stuSpendingLevel.spendingScore}) \ No newline at end of file diff --git a/src/main/resources/mappers/StuUserBehaviorMapper.xml b/src/main/resources/mappers/StuUserBehaviorMapper.xml index 96a6c43..70541be 100644 --- a/src/main/resources/mappers/StuUserBehaviorMapper.xml +++ b/src/main/resources/mappers/StuUserBehaviorMapper.xml @@ -235,8 +235,7 @@ - - + update stu_user_behavior @@ -384,11 +383,4 @@ update_time = #{updateTime,jdbcType=TIMESTAMP} where id = #{id,jdbcType=INTEGER} - \ No newline at end of file diff --git a/src/main/resources/mappers/StuUserLoginActiveMapper.xml b/src/main/resources/mappers/StuUserLoginActiveMapper.xml index 6c9580f..7e8c225 100644 --- a/src/main/resources/mappers/StuUserLoginActiveMapper.xml +++ b/src/main/resources/mappers/StuUserLoginActiveMapper.xml @@ -14,7 +14,7 @@ - + @@ -122,7 +122,7 @@ #{userName,jdbcType=VARCHAR}, #{studentId,jdbcType=VARCHAR}, #{stuClass,jdbcType=VARCHAR}, #{major,jdbcType=VARCHAR}, #{school,jdbcType=VARCHAR}, #{roleName,jdbcType=VARCHAR}, #{frequentLoginLocation,jdbcType=VARCHAR}, #{lastLoginDate,jdbcType=TIMESTAMP}, - #{loginFrequency,jdbcType=INTEGER}, #{loginDuration,jdbcType=VARCHAR}, #{createTime,jdbcType=TIMESTAMP}, + #{loginFrequency,jdbcType=INTEGER}, #{loginDuration,jdbcType=DOUBLE}, #{createTime,jdbcType=TIMESTAMP}, #{updateTime,jdbcType=TIMESTAMP}) @@ -212,7 +212,7 @@ #{loginFrequency,jdbcType=INTEGER}, - #{loginDuration,jdbcType=VARCHAR}, + #{loginDuration,jdbcType=DOUBLE}, #{createTime,jdbcType=TIMESTAMP}, @@ -228,8 +228,7 @@ - - + update stu_user_login_active @@ -269,7 +268,7 @@ login_frequency = #{record.loginFrequency,jdbcType=INTEGER}, - login_duration = #{record.loginDuration,jdbcType=VARCHAR}, + login_duration = #{record.loginDuration,jdbcType=DOUBLE}, create_time = #{record.createTime,jdbcType=TIMESTAMP}, @@ -296,7 +295,7 @@ frequent_login_location = #{record.frequentLoginLocation,jdbcType=VARCHAR}, last_login_date = #{record.lastLoginDate,jdbcType=TIMESTAMP}, login_frequency = #{record.loginFrequency,jdbcType=INTEGER}, - login_duration = #{record.loginDuration,jdbcType=VARCHAR}, + login_duration = #{record.loginDuration,jdbcType=DOUBLE}, create_time = #{record.createTime,jdbcType=TIMESTAMP}, update_time = #{record.updateTime,jdbcType=TIMESTAMP} @@ -340,7 +339,7 @@ login_frequency = #{loginFrequency,jdbcType=INTEGER}, - login_duration = #{loginDuration,jdbcType=VARCHAR}, + login_duration = #{loginDuration,jdbcType=DOUBLE}, create_time = #{createTime,jdbcType=TIMESTAMP}, @@ -364,16 +363,9 @@ frequent_login_location = #{frequentLoginLocation,jdbcType=VARCHAR}, last_login_date = #{lastLoginDate,jdbcType=TIMESTAMP}, login_frequency = #{loginFrequency,jdbcType=INTEGER}, - login_duration = #{loginDuration,jdbcType=VARCHAR}, + login_duration = #{loginDuration,jdbcType=DOUBLE}, create_time = #{createTime,jdbcType=TIMESTAMP}, update_time = #{updateTime,jdbcType=TIMESTAMP} where id = #{id,jdbcType=INTEGER} - \ No newline at end of file