Java-用线程池以及CountDownLatch优化代码 提高执行效率
1. 问题描述
客户提了一个新需求,开发完成后发现查询一小时内的数据耗时要 7 秒,这客户肯定不满意,不满意就要和领导提,领导不开心了我就要被扣工资!所以就想利用线程池优化一下代码,提高方法的效率。
2. 初始代码
点击查看代码
// 查询所有站点
QueryWrapper<Station> stationQW = new QueryWrapper<>();
stationQW.lambda().eq(Station::getRegionCode, region);
List<Station> stations = this.stationMapper.selectList(stationQW);
List<StationPO> stationPOList = new ArrayList<>();
long start = System.currentTimeMillis();
for (Station station : stations) {
long methodStart = System.currentTimeMillis();
String stationCode = station.getStationCode();
StationPO stationPO = new StationPO();
BeanUtils.copyProperties(station, stationPO);
// 总雨量
Float rainFall = stationDataMapper.queryRainFallByTime(startTime, endTime, stationCode);
stationPO.setRainFall(rainFall);
// 平均气温
Float avgTemp = stationDataMapper.queryAvgTempByTime(startTime, endTime, stationCode);
stationPO.setAvgTemp(avgTemp);
// 最高气温
Float maxTemp = stationDataMapper.queryMaxTempByTime(startTime, endTime, stationCode);
stationPO.setMaxTemp(maxTemp);
// 最低气温
Float minTemp = stationDataMapper.queryMinTempByTime(startTime, endTime, stationCode);
stationPO.setMinTemp(minTemp);
// 最大风速
Float maxWind = stationDataMapper.queryMaxWindByTime(startTime, endTime, stationCode);
stationPO.setMaxWind(maxWind);
// 极大风速
Float enormousWind = stationDataMapper.queryEnormousWindByTime(startTime, endTime, stationCode);
stationPO.setEnormousWind(enormousWind);
// 平均湿度
Float avgHumidity = stationDataMapper.queryAvgHumidityByTime(startTime, endTime, stationCode);
stationPO.setAvgHumidity(avgHumidity);
long methodEnd = System.currentTimeMillis();
System.out.println("7个查询耗时:" + new BigDecimal(methodEnd - methodStart).divide(new BigDecimal(1000)).doubleValue());
stationPOList.add(stationPO);
}
long end = System.currentTimeMillis();
System.out.println("方法耗时:" + new BigDecimal(end - start).divide(new BigDecimal(1000)).doubleValue());
我这边站点数据集合的大小是37,每次循环都有7个SQL语句,每个SQL的执行时间在0.8秒左右,时间都浪费在循环上了,所以设想循环都创建一个线程去执行任务,这样的话总耗时也就是一次循环的时间。
3. 用到的技术
- ThreadPoolExecutor线程池
- CountDownLatch锁
这里简单说一下CountDownLatch锁,作用就是一个线程会等待其他线程都执行完毕后再继续执行,具体是通过一个计数器来实现的,计数器的初始值是线程的数量。每当一个线程执行完毕后,计数器的值就-1,当计数器的值为0时,表示所有线程都执行完毕,然后在闭锁上等待的线程就可以恢复工作了。
4. 整体思路
首先创建一个线程池,然后创建锁,这里我直接把线程池的大小以及锁的count都设置成list的大小,也就是循环次数,开始循环,for循环开启线程,执行一个站点的查询数据SQL,查询完成后关闭一个锁(countDown方法)。循环外面等待所有线程结束后(await方法),关闭线程池(shutdown方法),结束。
5. 优化后代码
点击查看代码
// 查询所有站点
QueryWrapper<Station> stationQW = new QueryWrapper<>();
stationQW.lambda().eq(Station::getRegionCode, region);
List<Station> stations = this.stationMapper.selectList(stationQW);
List<StationPO> stationPOList = new ArrayList<>();
ThreadPoolExecutor poolExecutor = ExecutorBuilder.create()
.setCorePoolSize(stations.size()) // 初始线程
.setMaxPoolSize(stations.size()) // 最大线程
.setWorkQueue(new LinkedBlockingQueue<>(100)) // 线程池策略
.build();
CountDownLatch cdl = new CountDownLatch(stations.size());
long start = System.currentTimeMillis();
for (Station station : stations) {
poolExecutor.execute(
() -> {
long methodStart = System.currentTimeMillis();
String stationCode = station.getStationCode();
StationPO stationPO = new StationPO();
BeanUtils.copyProperties(station, stationPO);
// 总雨量
Float rainFall = stationDataMapper.queryRainFallByTime(startTime, endTime, stationCode);
stationPO.setRainFall(rainFall);
// 平均气温
Float avgTemp = stationDataMapper.queryAvgTempByTime(startTime, endTime, stationCode);
stationPO.setAvgTemp(avgTemp);
// 最高气温
Float maxTemp = stationDataMapper.queryMaxTempByTime(startTime, endTime, stationCode);
stationPO.setMaxTemp(maxTemp);
// 最低气温
Float minTemp = stationDataMapper.queryMinTempByTime(startTime, endTime, stationCode);
stationPO.setMinTemp(minTemp);
// 最大风速
Float maxWind = stationDataMapper.queryMaxWindByTime(startTime, endTime, stationCode);
stationPO.setMaxWind(maxWind);
// 极大风速
Float enormousWind = stationDataMapper.queryEnormousWindByTime(startTime, endTime, stationCode);
stationPO.setEnormousWind(enormousWind);
// 平均湿度
Float avgHumidity = stationDataMapper.queryAvgHumidityByTime(startTime, endTime, stationCode);
stationPO.setAvgHumidity(avgHumidity);
long methodEnd = System.currentTimeMillis();
System.out.println("7个查询耗时:" + new BigDecimal(methodEnd - methodStart).divide(new BigDecimal(1000)).doubleValue());
stationPOList.add(stationPO);
// 闭锁-1
cdl.countDown();
}
);
}
try {
// 等待所有线程结束
cdl.await();
} catch (InterruptedException e) {
StaticLog.error("线程错误:{}",e.getMessage());
}
poolExecutor.shutdown();
long end = System.currentTimeMillis();
System.out.println("方法耗时:" + new BigDecimal(end - start).divide(new BigDecimal(1000)).doubleValue());