whb 7 months ago
commit 32776fcaea

@ -30,7 +30,7 @@ import java.math.BigDecimal;
import java.text.DecimalFormat;
import java.util.*;
import static com.sztzjy.marketing.util.algorithm.KMeans.readIrisData;
import static com.sztzjy.marketing.util.algorithm.KMeans.*;
import static com.sztzjy.marketing.util.algorithm.KMeansResult.*;
/**
@ -62,8 +62,9 @@ public class StuDigitalMarketingModelController {
@PostMapping("/viewMetrics")
@AnonymousAccess
public ResultEntity viewMetrics(@ApiParam("用户ID") String userId,
@ApiParam("表格名") String tableName) {
List<String> list=modelService.viewMetrics(userId,tableName);
@ApiParam("表格名") String tableName,
@ApiParam("算法名称") String algorithmName) {
List<String> list=modelService.viewMetrics(userId,tableName,algorithmName);
return new ResultEntity(HttpStatus.OK,"成功",list);
}
@ -112,7 +113,9 @@ public class StuDigitalMarketingModelController {
Integer k = clusterScatterPlotDTO.getK();
List<Map<String, Object>> deduplicatedDataList = clusterScatterPlotDTO.getDeduplicatedDataList();
List<KMeans.Point> irisData = readIrisData(deduplicatedDataList);
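// flatten each deduplicated row map into a comma-delimited string, then parse the first two columns as (x, y) points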
List<String> strings = convertToDelimitedStringList(deduplicatedDataList);
List<KMeans.Point> irisData = readIrisData(strings);
//get the data set
KMeans kMeans = new KMeans(k, t, irisData);
@ -148,12 +151,16 @@ public class StuDigitalMarketingModelController {
@ApiOperation("聚类分析--分析结果")
@PostMapping("/clusterAnalysisResult")
@AnonymousAccess
public ResultEntity clusterAnalysisResult(@ApiParam("簇数") Integer k,
@ApiParam("最大迭代次数") Integer t,
@ApiParam("最大迭代次数") String userId,
@RequestParam(required = false) @RequestPart MultipartFile file) {
public ResultEntity clusterAnalysisResult(@RequestBody ClusterScatterPlotDTO clusterScatterPlotDTO) {
Integer t = clusterScatterPlotDTO.getT();
Integer k = clusterScatterPlotDTO.getK();
List<Map<String, Object>> deduplicatedDataList = clusterScatterPlotDTO.getDeduplicatedDataList();
List<String> strings = convertToDelimitedStringList(deduplicatedDataList);
//initialize the data
ArrayList<ArrayList<Float>> arrayLists = readTable(file);
ArrayList<ArrayList<Float>> arrayLists = readTable(strings);
//randomly select k data points as the initial cluster centers
@ -163,7 +170,7 @@ public class StuDigitalMarketingModelController {
ArrayList<ArrayList<ArrayList<Float>>> kmeans = kmeans(k, t, centerList);
return new ResultEntity(HttpStatus.OK);
return new ResultEntity(HttpStatus.OK,kmeans);
}
@ -171,20 +178,23 @@ public class StuDigitalMarketingModelController {
@ApiOperation("关联规则挖掘")
@PostMapping("/apriori")
@AnonymousAccess
public ResultEntity apriori(@ApiParam("最小支持度阀值") double support,
@ApiParam("最小置信度") double confidence,
@ApiParam("用户ID") String userId,
@RequestParam(required = false) @RequestPart MultipartFile file) throws IOException {
public ResultEntity apriori(@RequestBody AprioriDTO aprioriDTO) throws IOException {
double support = aprioriDTO.getSupport();
double confidence = aprioriDTO.getConfidence();
String userId = aprioriDTO.getUserId();
List<Map<String, Object>> deduplicatedDataList = aprioriDTO.getDeduplicatedDataList();
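// flatten the selected rows into comma-delimited transaction strings for Apriori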
List<String> strings = convertToDelimitedStringList(deduplicatedDataList);
//initialize the transaction database, item set and candidate set, then prune
Apriori.init(file,support);
Apriori.init(strings,support);
//iterate to obtain the final candidate frequent item sets
Apriori.iteration(Apriori.C,Apriori.L,support);
//from the final frequent item sets, compute each association rule using the confidence formula
List<AssociationRulesDTO> connection = Apriori.connection(confidence);
List<String> connection = Apriori.connection(confidence);
return new ResultEntity(HttpStatus.OK,"成功",connection);
@ -202,26 +212,38 @@ public class StuDigitalMarketingModelController {
LinearRegression regression=new LinearRegression();
regression.fit(x,y);
double[] slopes = regression.getSlopes();
RAnalysisDTO rAnalysisDTO=new RAnalysisDTO();
rAnalysisDTO.setIntercept(regression.getIntercept());
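// RAnalysisDTO holds a single slope value, so the loop below keeps only the last coefficient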
for (int i = 0; i < slopes.length; i++) {
rAnalysisDTO.setSlopes(slopes[i]);
}
return new ResultEntity(HttpStatus.OK,"成功",regression);
return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO);
}
// @ApiOperation("回归分析--逻辑回归")
// @PostMapping("/logistic")
// @AnonymousAccess
// public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) {
// double[] x = logisticDTO.getX();
// double[] y = logisticDTO.getY();
//
// // 创建线性回归模型
// LogisticRegression regression=new LogisticRegression();
// regression.fit(x,y);
//
//
// return new ResultEntity(HttpStatus.OK,"成功",regression);
// }
@ApiOperation("回归分析--逻辑回归")
@PostMapping("/logistic")
@AnonymousAccess
public ResultEntity logistic(@RequestBody LogisticDTO logisticDTO) {
double[][] x = logisticDTO.getX();
double[] y = logisticDTO.getY();
// create a logistic regression model instance with a 0.1 learning rate
LogisticRegression model = new LogisticRegression(0.1);
model.train(x,y,1000); //train the model
DecimalFormat df = new DecimalFormat("#.0");
RAnalysisDTO rAnalysisDTO=new RAnalysisDTO();
rAnalysisDTO.setIntercept(Double.parseDouble(df.format(model.getIntercept())));
rAnalysisDTO.setSlopes(Double.parseDouble(df.format(model.getCoefficient())));
return new ResultEntity(HttpStatus.OK,"成功",rAnalysisDTO);
}
@ApiOperation("情感分析/文本挖掘")
@ -237,7 +259,7 @@ public class StuDigitalMarketingModelController {
@ApiOperation("批量导入")
@PostMapping("/batchImport")
@AnonymousAccess
public ResultEntity batchImport(@RequestPart MultipartFile file) {
public ResultEntity batchImport(@RequestParam(required = false) @RequestPart MultipartFile file) {
//validate the file type
if (!file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf(".")).equals(".xls") && !file.getOriginalFilename().substring(file.getOriginalFilename().lastIndexOf(".")).equals(".xlsx")) {
@ -262,10 +284,10 @@ public class StuDigitalMarketingModelController {
StuSpendingLevel stuSpendingLevel=new StuSpendingLevel();
stuSpendingLevel.setId(Integer.valueOf((String) list.get(0)));
stuSpendingLevel.setGender(Integer.valueOf((String) list.get(0)));
stuSpendingLevel.setAge(Integer.valueOf((String) list.get(0)));
stuSpendingLevel.setAnnualIncome(Integer.valueOf((String) list.get(0)));
stuSpendingLevel.setSpendingScore(Integer.valueOf((String) list.get(0)));
stuSpendingLevel.setGender(Integer.valueOf((String) list.get(1)));
stuSpendingLevel.setAge(Integer.valueOf((String) list.get(2)));
stuSpendingLevel.setAnnualIncome(Integer.valueOf((String) list.get(3)));
stuSpendingLevel.setSpendingScore(Integer.valueOf((String) list.get(4)));

@ -2,7 +2,6 @@ package com.sztzjy.marketing.entity;
import java.util.Date;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
/**
*
@ -10,7 +9,6 @@ import io.swagger.annotations.ApiModelProperty;
* @author whb
* stu_user_login_active
*/
@ApiModel("用户登录活跃表")
public class StuUserLoginActive {
@ApiModelProperty("id")
private Integer id;
@ -48,8 +46,8 @@ public class StuUserLoginActive {
@ApiModelProperty("登录频次")
private Integer loginFrequency;
@ApiModelProperty("登录时长")
private String loginDuration;
@ApiModelProperty("登录时长(分钟)")
private Double loginDuration;
@ApiModelProperty("创建时间")
private Date createTime;
@ -153,12 +151,12 @@ public class StuUserLoginActive {
this.loginFrequency = loginFrequency;
}
public String getLoginDuration() {
public Double getLoginDuration() {
return loginDuration;
}
public void setLoginDuration(String loginDuration) {
this.loginDuration = loginDuration == null ? null : loginDuration.trim();
public void setLoginDuration(Double loginDuration) {
this.loginDuration = loginDuration;
}
public Date getCreateTime() {

@ -925,62 +925,52 @@ public class StuUserLoginActiveExample {
return (Criteria) this;
}
public Criteria andLoginDurationEqualTo(String value) {
public Criteria andLoginDurationEqualTo(Double value) {
addCriterion("login_duration =", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationNotEqualTo(String value) {
public Criteria andLoginDurationNotEqualTo(Double value) {
addCriterion("login_duration <>", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationGreaterThan(String value) {
public Criteria andLoginDurationGreaterThan(Double value) {
addCriterion("login_duration >", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationGreaterThanOrEqualTo(String value) {
public Criteria andLoginDurationGreaterThanOrEqualTo(Double value) {
addCriterion("login_duration >=", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationLessThan(String value) {
public Criteria andLoginDurationLessThan(Double value) {
addCriterion("login_duration <", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationLessThanOrEqualTo(String value) {
public Criteria andLoginDurationLessThanOrEqualTo(Double value) {
addCriterion("login_duration <=", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationLike(String value) {
addCriterion("login_duration like", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationNotLike(String value) {
addCriterion("login_duration not like", value, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationIn(List<String> values) {
public Criteria andLoginDurationIn(List<Double> values) {
addCriterion("login_duration in", values, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationNotIn(List<String> values) {
public Criteria andLoginDurationNotIn(List<Double> values) {
addCriterion("login_duration not in", values, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationBetween(String value1, String value2) {
public Criteria andLoginDurationBetween(Double value1, Double value2) {
addCriterion("login_duration between", value1, value2, "loginDuration");
return (Criteria) this;
}
public Criteria andLoginDurationNotBetween(String value1, String value2) {
public Criteria andLoginDurationNotBetween(Double value1, Double value2) {
addCriterion("login_duration not between", value1, value2, "loginDuration");
return (Criteria) this;
}

@ -0,0 +1,23 @@
package com.sztzjy.marketing.entity.dto;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
* @author tz
* @date 2024/8/17 14:16
*/
@Data
public class AprioriDTO {
@ApiModelProperty("最小支持度阀值")
private double support;
@ApiModelProperty("最小置信度")
private double confidence;
@ApiModelProperty("用户ID")
private String userId;
@ApiModelProperty("数据集")
private List<Map<String, Object>> deduplicatedDataList;
}

@ -18,7 +18,7 @@ public class DescriptiveStatistics {
private double median;
@ApiModelProperty("众数")
private List<Double> mode;
private double mode;
@ApiModelProperty("标准差")
private double standardDeviation;

@ -11,7 +11,7 @@ import lombok.Data;
public class LogisticDTO {
@ApiModelProperty("自变量x")
private double[] x;
private double[][] x;
@ApiModelProperty("因变量y")
private double[] y;

@ -0,0 +1,13 @@
package com.sztzjy.marketing.entity.dto;
import lombok.Data;
/**
* @author tz
* @date 2024/8/17 16:30
*/
@Data
public class RAnalysisDTO {
private double intercept;
private double slopes;
}

@ -3,10 +3,8 @@ package com.sztzjy.marketing.mapper;
import com.sztzjy.marketing.entity.StuUserBehavior;
import com.sztzjy.marketing.entity.StuUserBehaviorExample;
import java.util.List;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
@Mapper
public interface StuUserBehaviorMapper {
long countByExample(StuUserBehaviorExample example);
@ -29,6 +27,4 @@ public interface StuUserBehaviorMapper {
int updateByPrimaryKeySelective(StuUserBehavior record);
int updateByPrimaryKey(StuUserBehavior record);
List<StuUserBehavior> selectByFields(List<String> fieldList, String table);
}

@ -27,6 +27,4 @@ public interface StuUserLoginActiveMapper {
int updateByPrimaryKeySelective(StuUserLoginActive record);
int updateByPrimaryKey(StuUserLoginActive record);
List<StuUserLoginActive> selectByFields(List<String> fieldList, String table);
}

@ -20,7 +20,7 @@ public interface StuDigitalMarketingModelService {
ResultEntity descriptiveStatistics(Map<String, List<Double>> map, List<String> statistic, String userId);
List<String> viewMetrics(String userId, String tableName);
List<String> viewMetrics(String userId, String tableName,String algorithmName);
ResultEntity viewAnalyzeData(AnalyzeDataDTO analyzeDataDTO);

@ -19,6 +19,7 @@ import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.IOException;
import java.text.DecimalFormat;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
@ -190,13 +191,14 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
if(statistic.get(i).equals("众数")){
List<Double> mode = statisticsUtil.getMode(map.get(key));
Double mode = statisticsUtil.getMode(map.get(key));
statistics.setMode(mode);
}
if(statistic.get(i).equals("标准差")){
double standardDeviation = statisticsUtil.getStandardDeviation(map.get(key));
statistics.setStandardDeviation(standardDeviation);
DecimalFormat df = new DecimalFormat("#.0");
statistics.setStandardDeviation(Double.parseDouble(df.format(standardDeviation)));
}
if(statistic.get(i).equals("方差")){
@ -211,12 +213,24 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
if(statistic.get(i).equals("峰度")){
double kurtosis = statisticsUtil.getKurtosis(map.get(key));
statistics.setKurtosis(kurtosis);
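// kurtosis can come back as NaN for very small samples; fall back to 0 in that case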
if(Double.isNaN(kurtosis)){
statistics.setKurtosis(0);
}else {
statistics.setKurtosis(kurtosis);
}
}
if(statistic.get(i).equals("偏度")){
double skewness = statisticsUtil.getSkewness(map.get(key));
statistics.setSkewness(skewness);
if(Double.isNaN(skewness)){
statistics.setSkewness(0);
}else {
statistics.setSkewness(skewness);
}
}
if(statistic.get(i).equals("最大值")){
@ -248,7 +262,7 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
@Override
public List<String> viewMetrics(String userId, String tableName) {
public List<String> viewMetrics(String userId, String tableName,String algorithmName) {
List<String> list=new ArrayList<>();
if(tableName.equals(Constant.YHSXB)){
@ -258,14 +272,33 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
list.add("role_age");
}
if(tableName.equals(Constant.YHDLHYB)){
list=indicatorsMapper.getYHDLHYB();
// list=indicatorsMapper.getYHDLHYB();
list.add("id");
list.add("login_frequency");
list.add("login_duration");
}
if(tableName.equals(Constant.YHXFNLB)){
list=indicatorsMapper.getYHXFNLB();
// list=indicatorsMapper.getYHXFNLB();
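// association-rule mining only needs the goods column; other algorithms use the numeric metric columns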
if(algorithmName.equals("关联规则挖掘")){
list.add("consumer_goods");
} else {
list.add("id");
list.add("consumer_amount");
list.add("consumer_number");
list.add("consumer_past_seven_days_amount");
list.add("consumer_past_seven_days_number");
}
}
if(tableName.equals(Constant.YHPLB) || tableName.equals(Constant.YHXWB)){
list=indicatorsMapper.getYHPLB();
// list=indicatorsMapper.getYHPLB();
if(algorithmName.equals("关联规则挖掘")){
list.add("goods_name");
}else {
list.add("id");
list.add("user_behavior_type");
}
}
return list;
}
@ -285,7 +318,7 @@ public class StuDigitalMarketingModelServiceImpl implements StuDigitalMarketingM
}
if(analyzeDataDTO.getTableName().equals(Constant.YHXFNLB)){ //query the user spending-level table
table="stu_spending_level";
table="stu_user_consumption_ability";
}
if(analyzeDataDTO.getTableName().equals(Constant.YHPLB) || analyzeDataDTO.getTableName().equals(Constant.YHXWB)){ //query the user comment or behavior table

@ -30,25 +30,39 @@ public class Apriori {
// stores the result of every candidate-set pass, i.e. all frequent item sets L, so the final rule computation does not need to re-scan the transaction database.
static HashMap<ArrayList<String>, Integer> L_ALL = new HashMap<ArrayList<String>, Integer>();
//read the input rows and return the transaction set
public static ArrayList<ArrayList<String>> readTable(MultipartFile file) {
ArrayList<ArrayList<String>> t = new ArrayList<>();
try (InputStream inputStream = file.getInputStream();
Workbook workbook = new XSSFWorkbook(inputStream)) {
Sheet sheet = workbook.getSheetAt(0);
for (Row row : sheet) {
ArrayList<String> rowData = new ArrayList<>();
for (Cell cell : row) {
rowData.add(cell.getStringCellValue());
}
t.add(rowData);
public static ArrayList<ArrayList<String>> readTable(List<String> inputList) {
ArrayList<ArrayList<String>> result = new ArrayList<>();
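// each input string is one transaction; split on commas to recover its item list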
for (String line : inputList) {
String[] values = line.split(",");
ArrayList<String> rowData = new ArrayList<>();
for (String value : values) {
rowData.add(value);
}
} catch (IOException e) {
e.printStackTrace();
System.out.println("文件读取失败!");
result.add(rowData);
}
System.out.println("\nD:" + t);
return t;
System.out.println("\nD:" + result);
return result;
}
// public static ArrayList<ArrayList<String>> readTable(MultipartFile file) {
// ArrayList<ArrayList<String>> t = new ArrayList<>();
// try (InputStream inputStream = file.getInputStream();
// Workbook workbook = new XSSFWorkbook(inputStream)) {
// Sheet sheet = workbook.getSheetAt(0);
// for (Row row : sheet) {
// ArrayList<String> rowData = new ArrayList<>();
// for (Cell cell : row) {
// rowData.add(cell.getStringCellValue());
// }
// t.add(rowData);
// }
// } catch (IOException e) {
// e.printStackTrace();
// System.out.println("文件读取失败!");
// }
// System.out.println("\nD:" + t);
// return t;
// }
//pruning step: drop candidates in C that are below the minimum support and move the qualifying ones into frequent set L
public static void pruning(HashMap<ArrayList<String>, Integer> C,HashMap<ArrayList<String>, Integer> L,double min_support) {
@ -70,9 +84,9 @@ public class Apriori {
/**
*
*/
public static void init(MultipartFile file,double min_support) throws IOException {
public static void init( List<String> strings,double min_support) throws IOException {
// //load the input data into collection D
D = readTable(file);
D = readTable(strings);
// scan the transaction database and build the item set; support = occurrences of the element / total number of transactions
for (int i = 0; i < D.size(); i++) {
for (int j = 0; j < D.get(i).size(); j++) {
@ -198,10 +212,12 @@ public class Apriori {
/**
*
*/
public static List<AssociationRulesDTO> connection(double min_confident) {
public static List<String> connection(double min_confident) {
List<AssociationRulesDTO> list=new ArrayList<>();
List<String> lists=new ArrayList<>();
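// collect human-readable rule strings in the form antecedent==>consequent==>confidence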
for (ArrayList<String> key : L_ALL.keySet()) {// examine every event in the final association set
ArrayList<ArrayList<String>> key_allSubset = getSubset(key);
//get all subsets of each frequent item set
@ -235,8 +251,14 @@ public class Apriori {
list.add(associationRulesDTO);
System.out.print(item_pre + "==>" + item_post );// this is an association rule
System.out.println("==>" + confident);
String string=item_pre + "==>" + item_post+"==>" + confident;
lists.add(string);
}
}
@ -244,7 +266,7 @@ public class Apriori {
}
}
}
return list;
return lists;
}
public static void main(String[] args) {

@ -0,0 +1,11 @@
package com.sztzjy.marketing.util.algorithm;
public class Data {
public double[] features;
public double label;
public Data(double[] features, double label) {
this.features = features;
this.label = label;
}
}

@ -32,7 +32,7 @@ public class DescriptiveStatisticsUtil {
// compute the mode
public List<Double> getMode(List<Double> data) {
public Double getMode(List<Double> data) {
HashMap map = new HashMap();
double imode = 0;
for (int i = 0; i < data.size(); i++) {
@ -57,7 +57,10 @@ public class DescriptiveStatisticsUtil {
lst.add(key);
}
}
return lst;
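// when several values tie for the highest frequency, only the first candidate is returned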
Object object = lst.get(0);
Double v = Double.valueOf(object.toString());
return v;
}
@ -90,6 +93,7 @@ public class DescriptiveStatisticsUtil {
public double getKurtosis(List<Double> data) {
DescriptiveStatistics stats = new DescriptiveStatistics();
for (double value : data) {
stats.addValue(value);
}
return stats.getKurtosis();
@ -101,7 +105,8 @@ public class DescriptiveStatisticsUtil {
for (double value : data) {
stats.addValue(value);
}
return stats.getSkewness();
double skewness = stats.getSkewness();
return skewness;
}
// get the maximum value

@ -6,10 +6,9 @@ import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class KMeans {
@ -134,7 +133,7 @@ public class KMeans {
}
}
// public static void main(String[] args) {
public static void main(String[] args) {
// List<KMeans.Point> irisData = readIrisData("D:\\home\\marketing\\Iris.txt");
// int k = 3; // 假设聚为3类
// int maxIterations = 100;
@ -151,22 +150,35 @@ public class KMeans {
// for (int i = 0; i < centroids.size(); i++) {
// System.out.println("Centroid for cluster " + i + ": " + centroids.get(i));
// }
// }
public static List<KMeans.Point> readIrisData(List<Map<String, Object>> deduplicatedDataList) {
List<KMeans.Point> points = new ArrayList<>();
for (Map<String, Object> data : deduplicatedDataList) {
// 假设每个Map都有"x"和"y"键并且它们的值是Double类型
// 这里没有错误处理你可能需要添加一些来确保键存在且值可以转换为Double
Double x = (Double) data.get("x");
Double y = (Double) data.get("y");
if (x != null && y != null) {
points.add(new KMeans.Point(x, y));
} else {
// 处理缺少x或y的情况比如记录日志或抛出异常
System.err.println("数据中缺少x或y: " + data);
}
public static List<String> convertToDelimitedStringList(List<Map<String, Object>> nestedList) {
List<String> flatList = new ArrayList<>();
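// join each row map's values with commas, preserving the map's iteration order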
for (Map<String, Object> innerMap : nestedList) {
StringBuilder sb = new StringBuilder();
for (Object value : innerMap.values()) {
if (sb.length() > 0) {
sb.append(",");
}
sb.append(value);
}
flatList.add(sb.toString());
}
return flatList;
}
public static List<KMeans.Point> readIrisData(List<String> data) {
List<KMeans.Point> points = new ArrayList<>();
for (String line : data) {
String[] values = line.split(",");
double x = Double.parseDouble(values[0]);
double y = Double.parseDouble(values[1]);
points.add(new KMeans.Point(x, y));
}
return points;
}

@ -8,6 +8,7 @@ import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
@ -32,33 +33,48 @@ public class KMeansResult {
static ArrayList<ArrayList<Float>> randomList = new ArrayList<ArrayList<Float>>();
// read the data rows and store them in the collection
public static ArrayList<ArrayList<Float>> readTable(MultipartFile file){
public static ArrayList<ArrayList<Float>> readTable(List<String> data) {
table.clear();
ArrayList<Float> d = null;
try (BufferedReader br = new BufferedReader(new InputStreamReader(file.getInputStream()))) {
// 跳过标题行
br.readLine();
String str = null;
while((str = br.readLine()) != null) {
d = new ArrayList<Float>();
String[] str1 = str.split(",");
for(int i = 0; i < str1.length ; i++) {
d.add(Float.parseFloat(str1[i]));
}
table.add(d);
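// each incoming string is one data row; parse its comma-separated values into floats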
for (String line : data) {
ArrayList<Float> d = new ArrayList<Float>();
String[] str1 = line.split(",");
for (String str2 : str1) {
d.add(Float.parseFloat(str2));
}
// System.out.println(table);
} catch (Exception e) {
e.printStackTrace();
System.out.println("文件不存在!");
table.add(d);
}
return table;
}
// public static ArrayList<ArrayList<Float>> readTable(MultipartFile file){
// table.clear();
// ArrayList<Float> d = null;
//
// try (BufferedReader br = new BufferedReader(new InputStreamReader(file.getInputStream()))) {
//
//
// // 跳过标题行
// br.readLine();
//
// String str = null;
// while((str = br.readLine()) != null) {
// d = new ArrayList<Float>();
// String[] str1 = str.split(",");
// for(int i = 0; i < str1.length ; i++) {
// d.add(Float.parseFloat(str1[i]));
// }
// table.add(d);
// }
//// System.out.println(table);
//
// } catch (Exception e) {
// e.printStackTrace();
// System.out.println("文件不存在!");
// }
// return table;
// }
// randomly generate k initial cluster centers
public static ArrayList<ArrayList<Float>> randomList(Integer k,ArrayList<ArrayList<Float>> randomList) {
@ -66,15 +82,16 @@ public class KMeansResult {
int[] list = new int[k];
// Generate k distinct random numbers
int size = table.size(); //the random indices must stay within the number of rows in table
do {
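// draw k random row indices; redraw until they are all distinct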
for (int i = 0; i < k; i++) {
list[i] = (int)(Math.random() * 30);
list[i] = (int)(Math.random() * size);
}
} while (hasDuplicates(list));
// use the k randomly chosen rows as the initial cluster centers
for (int i = 0; i < k; i++) {
centerList.add(table.get(list[i]));
}
@ -193,11 +210,24 @@ public class KMeansResult {
private static ArrayList<Float> calculateCentroid(ArrayList<ArrayList<Float>> cluster) {
float sum1 = 0, sum2 = 0, sum3 = 0, sum4 = 0;
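// accumulate up to the first four coordinates of every point in the cluster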
for (ArrayList<Float> point : cluster) {
sum1 += point.get(1);
sum2 += point.get(2);
sum3 += point.get(3);
sum4 += point.get(4);
for (int i = 0; i < point.size(); i++) {
if(i==0){
sum1 += point.get(0);
}
if(i==1){
sum2 += point.get(1);
}
if(i==2){
sum3 += point.get(2);
}
if(i==3){
sum4 += point.get(3);
}
}
}
float size = cluster.size();
ArrayList<Float> centroid = new ArrayList<>();
centroid.add(0f);

@ -1,6 +1,8 @@
package com.sztzjy.marketing.util.algorithm;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
public class LinearRegression {
private double intercept;
@ -76,14 +78,21 @@ public class LinearRegression {
// create the linear regression model
LinearRegression lr = new LinearRegression();
lr.fit(x, y);
DecimalFormat df = new DecimalFormat("#.0");
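// format the reported coefficients to one decimal place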
List<String> list=new ArrayList<>();
for (int i = 0; i < lr.getSlopes().length; i++) {
String format = df.format(lr.getSlopes()[i]);
list.add(format);
}
// print the model parameters
System.out.println("截距: " + lr.getIntercept());
System.out.println("斜率: " + java.util.Arrays.toString(lr.getSlopes()));
System.out.println("常数项值: " + df.format(lr.getIntercept()));
System.out.println("第一个特征系数值: " + list.get(0));
// predict on new data
double[] newX = {2, 45};
double[] newX = {2, 88};
double predictedY = lr.predict(newX);
System.out.println("输入x: " + java.util.Arrays.toString(newX) + "预测y=" + predictedY);
System.out.println("输入x: " + java.util.Arrays.toString(newX) + "预测y=" + df.format(predictedY));
}
}

@ -1,71 +1,121 @@
package com.sztzjy.marketing.util.algorithm;
import java.util.Arrays;
import com.sztzjy.marketing.util.BigDecimalUtils;
import java.text.DecimalFormat;
public class LogisticRegression {
private double[] weights;
private double bias;
private double theta0; // intercept term
private double theta1; // slope
private double learningRate; // learning rate
private int currentEpoch; // current training epoch counter
public LogisticRegression(int numFeatures) {
weights = new double[numFeatures];
bias = 0;
public LogisticRegression(double learningRate) {
this.learningRate = learningRate;
}
public void fit(double[][] X, double[] y, double learningRate, int numIterations) {
int numSamples = X.length;
int numFeatures = X[0].length;
for (int iter = 0; iter < numIterations; iter++) {
double[] gradientWeights = new double[numFeatures];
double gradientBias = 0;
for (int i = 0; i < numSamples; i++) {
double dotProduct = dotProduct(X[i], weights) + bias;
double prediction = sigmoid(dotProduct);
double error = y[i] - prediction;
for (int j = 0; j < numFeatures; j++) {
gradientWeights[j] += error * X[i][j];
}
gradientBias += error;
}
public LogisticRegression(double learningRate, double initialTheta0, double initialTheta1) {
this.learningRate = learningRate;
this.theta0 = initialTheta0;
this.theta1 = initialTheta1;
this.currentEpoch = 0; // initialize currentEpoch in the constructor
}
for (int j = 0; j < numFeatures; j++) {
weights[j] += learningRate * gradientWeights[j] / numSamples;
}
bias += learningRate * gradientBias / numSamples;
}
// accessor for the intercept term
public double getIntercept() {
return theta0;
}
public double predict(double[] x) {
double dotProduct = dotProduct(x, weights) + bias;
return sigmoid(dotProduct);
// accessor for the coefficient of the first feature
public double getCoefficient() {
return theta1;
}
private double sigmoid(double x) {
return 1 / (1 + Math.exp(-x));
// sigmoid function
private double sigmoid(double z) {
return 1 / (1 + Math.exp(-z));
}
private double dotProduct(double[] x, double[] y) {
double sum = 0;
for (int i = 0; i < x.length; i++) {
sum += x[i] * y[i];
// prediction function
public double predict(double x) {
BigDecimalUtils bigDecimalUtils=new BigDecimalUtils();
Double mul = bigDecimalUtils.mul(this.theta0, x);
return bigDecimalUtils.add(this.theta1, mul);
}
private double dotProduct(double[] a, double[] b) {
double result = 0.0;
for (int i = 0; i < a.length; i++) {
result += a[i] * b[i];
}
return sum;
return result;
}
// update the parameters with gradient descent
public void updateParameters(double x, double y) {
double h = predict(x);
double error = y - h;
theta1 += learningRate * error * h * (1 - h) * x;
theta0 += learningRate * error * h * (1 - h);
}
private double exponentialDecayLearningRate(double currentEpoch, int totalEpochs) {
return learningRate * Math.exp(-(currentEpoch / totalEpochs));
}
// train the model
public void train(double[][] data, double[] labels, int epochs) {
double learningRateCurrentEpoch = exponentialDecayLearningRate(currentEpoch, epochs);
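// the decayed learning rate is computed once before the epoch loop and reused for every epoch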
for (int epoch = 0; epoch < epochs; epoch++) {
for (int i = 0; i < data.length; i++) {
double x = data[i][0];
double y = labels[i];
double h = sigmoid(theta0 + theta1 * x);
double error = y - h;
theta1 += learningRateCurrentEpoch * error * h * (1 - h) * x;
theta0 += learningRateCurrentEpoch * error * h * (1 - h);
}
// a print statement could be added here to monitor training progress
// System.out.println("Epoch: " + (epoch + 1) + ", Loss: " + calculateLoss(data, labels));
}
}
public double calculateLoss(double[][] data, double[] labels) {
double loss = 0;
for (int i = 0; i < data.length; i++) {
double x = data[i][0];
double y = labels[i];
double h = sigmoid(theta0 + theta1 * x);
loss += -y * Math.log(h) - (1 - y) * Math.log(1 - h);
}
return loss / data.length;
}
public static void main(String[] args) {
double[][] X = {{1, 1}, {1, 2}, {2, 1}, {2, 3}};
double[] y = {0, 0, 0, 1};
LogisticRegression model = new LogisticRegression(0.1); // create a logistic regression model instance
double[][] data = {{1}, {2}, {3}, {4}, {5}};
double[] labels = {0, 0, 1, 1, 1}; // assume a binary classification problem
// double[][] data = {{1, 19}, {1, 21}, {2, 20}, {2, 23}, {2, 31}};
// double[] labels = {15, 15, 16, 16, 17};
model.train(data, labels, 1000); // train the model
LogisticRegression model = new LogisticRegression(2);
model.fit(X, y, 0.01, 10000);
DecimalFormat df = new DecimalFormat("#.0");
System.out.println("Weights: " + Arrays.toString(model.weights));
System.out.println("Bias: " + model.bias);
// fetch and print the intercept and coefficient
double intercept = model.getIntercept();
double coefficient = model.getCoefficient();
System.out.println("常数项值: " + df.format(intercept));
System.out.println("系数值 " + df.format(coefficient));
double[] newSample = {1, 2};
double prediction = model.predict(newSample);
System.out.println("Prediction for " + Arrays.toString(newSample) + ": " + prediction);
double prediction = model.predict(2);
System.out.println("x=3.5的预测: " + df.format(prediction));
}
}

@ -6,6 +6,7 @@ import org.apache.poi.ss.usermodel.*;
import org.apache.poi.xssf.usermodel.XSSFWorkbook;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigDecimal;
import java.text.DecimalFormat;

@ -209,10 +209,10 @@
where id = #{id,jdbcType=INTEGER}
</update>
<insert id="addList">
INSERT INTO tch_start_course_name_list (id, gender, age, annual_income, spending_score)
INSERT INTO stu_spending_level (id, gender, age, annual_income, spending_score)
VALUES
<foreach collection="stuSpendingLevels" item="stuSpendingLevel" separator=",">
(#{stuSpendingLevel.id}, #{stuSpendingLevel.gender}, #{stuSpendingLevel.age}, #{stuSpendingLevel.annual_income}, #{stuSpendingLevel.spending_score})
(#{stuSpendingLevel.id}, #{stuSpendingLevel.gender}, #{stuSpendingLevel.age}, #{stuSpendingLevel.annualIncome}, #{stuSpendingLevel.spendingScore})
</foreach>
</insert>
</mapper>

@ -235,8 +235,7 @@
<include refid="Example_Where_Clause" />
</if>
</select>
<update id="updateByExampleSelective" parameterType="map">
<update id="updateByExampleSelective" parameterType="map">
update stu_user_behavior
<set>
<if test="record.id != null">
@ -384,11 +383,4 @@
update_time = #{updateTime,jdbcType=TIMESTAMP}
where id = #{id,jdbcType=INTEGER}
</update>
<select id="selectByFields" resultType="com.sztzjy.marketing.entity.StuUserBehavior" resultMap="BaseResultMap">
SELECT
<foreach collection="fieldList" item="field" separator=",">
${field}
</foreach>
FROM ${table}
</select>
</mapper>

@ -14,7 +14,7 @@
<result column="frequent_login_location" jdbcType="VARCHAR" property="frequentLoginLocation" />
<result column="last_login_date" jdbcType="TIMESTAMP" property="lastLoginDate" />
<result column="login_frequency" jdbcType="INTEGER" property="loginFrequency" />
<result column="login_duration" jdbcType="VARCHAR" property="loginDuration" />
<result column="login_duration" jdbcType="DOUBLE" property="loginDuration" />
<result column="create_time" jdbcType="TIMESTAMP" property="createTime" />
<result column="update_time" jdbcType="TIMESTAMP" property="updateTime" />
</resultMap>
@ -122,7 +122,7 @@
#{userName,jdbcType=VARCHAR}, #{studentId,jdbcType=VARCHAR}, #{stuClass,jdbcType=VARCHAR},
#{major,jdbcType=VARCHAR}, #{school,jdbcType=VARCHAR}, #{roleName,jdbcType=VARCHAR},
#{frequentLoginLocation,jdbcType=VARCHAR}, #{lastLoginDate,jdbcType=TIMESTAMP},
#{loginFrequency,jdbcType=INTEGER}, #{loginDuration,jdbcType=VARCHAR}, #{createTime,jdbcType=TIMESTAMP},
#{loginFrequency,jdbcType=INTEGER}, #{loginDuration,jdbcType=DOUBLE}, #{createTime,jdbcType=TIMESTAMP},
#{updateTime,jdbcType=TIMESTAMP})
</insert>
<insert id="insertSelective" parameterType="com.sztzjy.marketing.entity.StuUserLoginActive">
@ -212,7 +212,7 @@
#{loginFrequency,jdbcType=INTEGER},
</if>
<if test="loginDuration != null">
#{loginDuration,jdbcType=VARCHAR},
#{loginDuration,jdbcType=DOUBLE},
</if>
<if test="createTime != null">
#{createTime,jdbcType=TIMESTAMP},
@ -228,8 +228,7 @@
<include refid="Example_Where_Clause" />
</if>
</select>
<update id="updateByExampleSelective" parameterType="map">
<update id="updateByExampleSelective" parameterType="map">
update stu_user_login_active
<set>
<if test="record.id != null">
@ -269,7 +268,7 @@
login_frequency = #{record.loginFrequency,jdbcType=INTEGER},
</if>
<if test="record.loginDuration != null">
login_duration = #{record.loginDuration,jdbcType=VARCHAR},
login_duration = #{record.loginDuration,jdbcType=DOUBLE},
</if>
<if test="record.createTime != null">
create_time = #{record.createTime,jdbcType=TIMESTAMP},
@ -296,7 +295,7 @@
frequent_login_location = #{record.frequentLoginLocation,jdbcType=VARCHAR},
last_login_date = #{record.lastLoginDate,jdbcType=TIMESTAMP},
login_frequency = #{record.loginFrequency,jdbcType=INTEGER},
login_duration = #{record.loginDuration,jdbcType=VARCHAR},
login_duration = #{record.loginDuration,jdbcType=DOUBLE},
create_time = #{record.createTime,jdbcType=TIMESTAMP},
update_time = #{record.updateTime,jdbcType=TIMESTAMP}
<if test="_parameter != null">
@ -340,7 +339,7 @@
login_frequency = #{loginFrequency,jdbcType=INTEGER},
</if>
<if test="loginDuration != null">
login_duration = #{loginDuration,jdbcType=VARCHAR},
login_duration = #{loginDuration,jdbcType=DOUBLE},
</if>
<if test="createTime != null">
create_time = #{createTime,jdbcType=TIMESTAMP},
@ -364,16 +363,9 @@
frequent_login_location = #{frequentLoginLocation,jdbcType=VARCHAR},
last_login_date = #{lastLoginDate,jdbcType=TIMESTAMP},
login_frequency = #{loginFrequency,jdbcType=INTEGER},
login_duration = #{loginDuration,jdbcType=VARCHAR},
login_duration = #{loginDuration,jdbcType=DOUBLE},
create_time = #{createTime,jdbcType=TIMESTAMP},
update_time = #{updateTime,jdbcType=TIMESTAMP}
where id = #{id,jdbcType=INTEGER}
</update>
<select id="selectByFields" resultType="com.sztzjy.marketing.entity.StuUserLoginActive" resultMap="BaseResultMap">
SELECT
<foreach collection="fieldList" item="field" separator=",">
${field}
</foreach>
FROM ${table}
</select>
</mapper>