From ae3a1781e25eaa81922ac8423b1dffdb7d385e78 Mon Sep 17 00:00:00 2001 From: ToB213 Date: Wed, 22 Apr 2026 21:04:35 +0900 Subject: [PATCH] fix: replace O(N*K*N) BFS with multi-source Dijkstra in KMeansClustering --- .../module/algorithm/KMeansClustering.java | 218 +++++++++++------- 1 file changed, 140 insertions(+), 78 deletions(-) diff --git a/src/main/java/adf/impl/module/algorithm/KMeansClustering.java b/src/main/java/adf/impl/module/algorithm/KMeansClustering.java index 83b9937..4b56a52 100644 --- a/src/main/java/adf/impl/module/algorithm/KMeansClustering.java +++ b/src/main/java/adf/impl/module/algorithm/KMeansClustering.java @@ -12,15 +12,15 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.PriorityQueue; import java.util.Random; -import java.util.Set; import rescuecore2.misc.Pair; -import rescuecore2.misc.collections.LazyMap; import rescuecore2.misc.geometry.Point2D; import rescuecore2.standard.entities.Area; import rescuecore2.standard.entities.Blockade; @@ -52,7 +52,8 @@ public class KMeansClustering extends StaticClustering { private boolean assignAgentsFlag; - private Map> shortestPathGraph; + // Edge-weighted adjacency list: area ID -> (neighbour area ID -> travel cost) + private Map> weightedGraph; public KMeansClustering(AgentInfo ai, WorldInfo wi, ScenarioInfo si, ModuleManager moduleManager, DevelopData developData) { super(ai, wi, si, moduleManager, developData); @@ -203,7 +204,7 @@ public Clustering calc() { private void calcStandard(int repeat) { - this.initShortestPath(this.worldInfo); + this.initWeightedGraph(this.worldInfo); Random random = new Random(); List entityList = new ArrayList<>(this.entities); @@ -238,16 +239,18 @@ private void calcStandard(int repeat) { this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity); } for (int index = 0; index < this.clusterSize; index++) { + List clusterEntities = this.clusterEntitiesList.get(index); + if (clusterEntities.isEmpty()) continue; int sumX = 0, sumY = 0; - for (StandardEntity entity : this.clusterEntitiesList.get(index)) { + for (StandardEntity entity : clusterEntities) { Pair location = this.worldInfo.getLocation(entity); sumX += location.first(); sumY += location.second(); } - int centerX = sumX / this.clusterEntitiesList.get(index).size(); - int centerY = sumY / this.clusterEntitiesList.get(index).size(); + int centerX = sumX / clusterEntities.size(); + int centerY = sumY / clusterEntities.size(); StandardEntity center = this.getNearEntityByLine(this.worldInfo, - this.clusterEntitiesList.get(index), centerX, centerY); + clusterEntities, centerX, centerY); if (center instanceof Area) { this.centerList.set(index, center); } else if (center instanceof Human) { @@ -278,8 +281,6 @@ private void calcStandard(int repeat) { this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity); } - // this.clusterEntitiesList.sort(comparing(List::size, reverseOrder())); - if (this.assignAgentsFlag) { List firebrigadeList = new ArrayList<>( this.worldInfo.getEntitiesOfType(StandardEntityURN.FIRE_BRIGADE)); @@ -309,7 +310,7 @@ private void calcStandard(int repeat) { private void calcPathBased(int repeat) { - this.initShortestPath(this.worldInfo); + this.initWeightedGraph(this.worldInfo); Random random = new Random(); List entityList = new ArrayList<>(this.entities); this.centerList = new ArrayList<>(this.clusterSize); @@ -327,30 +328,39 @@ private void calcPathBased(int repeat) { } while (this.centerList.contains(centerEntity)); this.centerList.set(index, centerEntity); } + for (int i = 0; i < repeat; i++) { this.clusterEntitiesList.clear(); for (int index = 0; index < this.clusterSize; index++) { this.clusterEntitiesList.put(index, new ArrayList<>()); } + + // Assign each entity to its nearest center using multi-source Dijkstra. + // All K centers are expanded simultaneously on the weighted road graph, + // so each entity is assigned to the center with the shortest path distance. + // Complexity: O(N log N) per iteration (formerly: O(N * K * N)) + Map assignment = this.assignByMultiSourceDijkstra(this.centerList); for (StandardEntity entity : entityList) { - StandardEntity tmp = this.getNearEntity(this.worldInfo, this.centerList, - entity); - this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity); + Integer clusterIndex = assignment.get(entity.getID()); + if (clusterIndex != null) { + this.clusterEntitiesList.get(clusterIndex).add(entity); + } } + + // Update centers: move each center to the entity nearest to the cluster centroid for (int index = 0; index < this.clusterSize; index++) { + List clusterEntities = this.clusterEntitiesList.get(index); + if (clusterEntities.isEmpty()) continue; int sumX = 0, sumY = 0; - for (StandardEntity entity : this.clusterEntitiesList.get(index)) { + for (StandardEntity entity : clusterEntities) { Pair location = this.worldInfo.getLocation(entity); sumX += location.first(); sumY += location.second(); } - int centerX = sumX / clusterEntitiesList.get(index).size(); - int centerY = sumY / clusterEntitiesList.get(index).size(); - - // this.centerList.set(index, getNearEntity(this.worldInfo, - // this.clusterEntitiesList.get(index), centerX, centerY)); - StandardEntity center = this.getNearEntity(this.worldInfo, - this.clusterEntitiesList.get(index), centerX, centerY); + int centerX = sumX / clusterEntities.size(); + int centerY = sumY / clusterEntities.size(); + StandardEntity center = this.getNearEntityByLine(this.worldInfo, + clusterEntities, centerX, centerY); if (center instanceof Area) { this.centerList.set(index, center); } else if (center instanceof Human) { @@ -361,6 +371,7 @@ private void calcPathBased(int repeat) { this.worldInfo.getEntity(((Blockade) center).getPosition())); } } + if (scenarioInfo.isDebugMode()) { System.out.print("*"); } @@ -370,16 +381,19 @@ private void calcPathBased(int repeat) { System.out.println(); } + // Final assignment: run multi-source Dijkstra once more with the converged centers this.clusterEntitiesList.clear(); for (int index = 0; index < this.clusterSize; index++) { this.clusterEntitiesList.put(index, new ArrayList<>()); } + Map finalAssignment = this.assignByMultiSourceDijkstra(this.centerList); for (StandardEntity entity : entityList) { - StandardEntity tmp = this.getNearEntity(this.worldInfo, this.centerList, - entity); - this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity); + Integer clusterIndex = finalAssignment.get(entity.getID()); + if (clusterIndex != null) { + this.clusterEntitiesList.get(clusterIndex).add(entity); + } } - // this.clusterEntitiesList.sort(comparing(List::size, reverseOrder())); + if (this.assignAgentsFlag) { List fireBrigadeList = new ArrayList<>( this.worldInfo.getEntitiesOfType(StandardEntityURN.FIRE_BRIGADE)); @@ -407,6 +421,99 @@ private void calcPathBased(int repeat) { } + /** + * Voronoi partition using multi-source Dijkstra. + * + * All K centers are seeded into a priority queue at distance 0 and expanded + * simultaneously using edge-weighted shortest-path search. Each entity is + * assigned to the cluster whose center has the minimum path distance to it. + * + * Complexity: O(N log N) (formerly O(N * K * N) with comparePathDistance) + * + * @param centers list of cluster centers + * @return map from EntityID to cluster index + */ + private Map assignByMultiSourceDijkstra(List centers) { + Map dist = new HashMap<>(); + Map assignment = new HashMap<>(); + // Entry: [distance, clusterIndex, entityID_value] + PriorityQueue pq = new PriorityQueue<>( + Comparator.comparingDouble(a -> a[0])); + + for (int i = 0; i < centers.size(); i++) { + EntityID cid = centers.get(i).getID(); + if (!dist.containsKey(cid)) { + dist.put(cid, 0.0); + assignment.put(cid, i); + pq.offer(new double[]{0.0, i, cid.getValue()}); + } + } + + while (!pq.isEmpty()) { + double[] cur = pq.poll(); + double d = cur[0]; + int ci = (int) cur[1]; + EntityID uid = new EntityID((int) cur[2]); + + if (d > dist.getOrDefault(uid, Double.MAX_VALUE)) continue; + assignment.putIfAbsent(uid, ci); + + Map neighbours = weightedGraph.get(uid); + if (neighbours == null) continue; + for (Map.Entry entry : neighbours.entrySet()) { + double nd = d + entry.getValue(); + if (nd < dist.getOrDefault(entry.getKey(), Double.MAX_VALUE)) { + dist.put(entry.getKey(), nd); + assignment.put(entry.getKey(), ci); + pq.offer(new double[]{nd, ci, entry.getKey().getValue()}); + } + } + } + + return assignment; + } + + + /** + * Build an edge-weighted adjacency graph. + * + * The travel cost between adjacent areas A and B is defined as: + * cost(A -> B) = dist(A.center, midpoint of the A-B edge) + * + dist(B.center, midpoint of the B-A edge) + * + * This matches the per-edge cost used in getPathDistance(), ensuring that + * assignByMultiSourceDijkstra produces assignments equivalent to the original + * path-distance Voronoi partition. + * + * Replaces the former initShortestPath() (unweighted adjacency graph). + */ + private void initWeightedGraph(WorldInfo worldInfo) { + this.weightedGraph = new HashMap<>(); + for (Entity entity : worldInfo) { + if (!(entity instanceof Area)) continue; + Area area = (Area) entity; + Pair aCenter = worldInfo.getLocation(area); + Map neighbours = new HashMap<>(); + + for (EntityID neighbourId : area.getNeighbours()) { + Entity neighbourEntity = worldInfo.getEntity(neighbourId); + if (!(neighbourEntity instanceof Area)) continue; + Area neighbour = (Area) neighbourEntity; + + Edge edgeFromArea = area.getEdgeTo(neighbourId); + Edge edgeFromNeighbour = neighbour.getEdgeTo(area.getID()); + if (edgeFromArea == null || edgeFromNeighbour == null) continue; + + Pair nCenter = worldInfo.getLocation(neighbour); + double weight = getDistance(aCenter, edgeFromArea) + + getDistance(nCenter, edgeFromNeighbour); + neighbours.put(neighbourId, weight); + } + this.weightedGraph.put(area.getID(), neighbours); + } + } + + private void assignAgents(WorldInfo world, List agentList) { int clusterIndex = 0; while (agentList.size() > 0) { @@ -462,19 +569,6 @@ private StandardEntity getNearAgent(WorldInfo worldInfo, } - private StandardEntity getNearEntity(WorldInfo worldInfo, - List srcEntityList, int targetX, int targetY) { - StandardEntity result = null; - for (StandardEntity entity : srcEntityList) { - result = (result != null) - ? this.compareLineDistance(worldInfo, targetX, targetY, result, - entity) - : entity; - } - return result; - } - - private Point2D getEdgePoint(Edge edge) { Point2D start = edge.getStart(); Point2D end = edge.getEnd(); @@ -523,18 +617,7 @@ private StandardEntity compareLineDistance(WorldInfo worldInfo, int targetX, } - private StandardEntity getNearEntity(WorldInfo worldInfo, - List srcEntityList, StandardEntity targetEntity) { - StandardEntity result = null; - for (StandardEntity entity : srcEntityList) { - result = (result != null) - ? this.comparePathDistance(worldInfo, targetEntity, result, entity) - : entity; - } - return result; - } - - + // Used by getNearAgent() for initial agent placement only. private StandardEntity comparePathDistance(WorldInfo worldInfo, StandardEntity target, StandardEntity first, StandardEntity second) { double firstDistance = getPathDistance(worldInfo, @@ -570,36 +653,13 @@ private double getPathDistance(WorldInfo worldInfo, List path) { } - private void initShortestPath(WorldInfo worldInfo) { - Map> neighbours = new LazyMap>() { - - @Override - public Set createValue() { - return new HashSet<>(); - } - }; - for (Entity next : worldInfo) { - if (next instanceof Area) { - Collection areaNeighbours = ((Area) next).getNeighbours(); - neighbours.get(next.getID()).addAll(areaNeighbours); - } - } - for (Map.Entry> graph : neighbours.entrySet()) {// fix - // graph - for (EntityID entityID : graph.getValue()) { - neighbours.get(entityID).add(graph.getKey()); - } - } - this.shortestPathGraph = neighbours; - } - - private List shortestPath(EntityID start, EntityID... goals) { return shortestPath(start, Arrays.asList(goals)); } + // BFS used by comparePathDistance() for getNearAgent(). + // Reuses weightedGraph as an adjacency list (edge weights are ignored here). private List shortestPath(EntityID start, Collection goals) { List open = new LinkedList<>(); @@ -614,7 +674,9 @@ private List shortestPath(EntityID start, found = true; break; } - Collection neighbours = shortestPathGraph.get(next); + Collection neighbours = weightedGraph.containsKey(next) + ? weightedGraph.get(next).keySet() + : Collections.emptySet(); if (neighbours.isEmpty()) continue;