Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 140 additions & 78 deletions src/main/java/adf/impl/module/algorithm/KMeansClustering.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import rescuecore2.misc.Pair;
import rescuecore2.misc.collections.LazyMap;
import rescuecore2.misc.geometry.Point2D;
import rescuecore2.standard.entities.Area;
import rescuecore2.standard.entities.Blockade;
Expand Down Expand Up @@ -52,7 +52,8 @@ public class KMeansClustering extends StaticClustering {

private boolean assignAgentsFlag;

private Map<EntityID, Set<EntityID>> shortestPathGraph;
// Edge-weighted adjacency list: area ID -> (neighbour area ID -> travel cost)
private Map<EntityID, Map<EntityID, Double>> weightedGraph;

public KMeansClustering(AgentInfo ai, WorldInfo wi, ScenarioInfo si, ModuleManager moduleManager, DevelopData developData) {
super(ai, wi, si, moduleManager, developData);
Expand Down Expand Up @@ -203,7 +204,7 @@ public Clustering calc() {


private void calcStandard(int repeat) {
this.initShortestPath(this.worldInfo);
this.initWeightedGraph(this.worldInfo);
Random random = new Random();

List<StandardEntity> entityList = new ArrayList<>(this.entities);
Expand Down Expand Up @@ -238,16 +239,18 @@ private void calcStandard(int repeat) {
this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity);
}
for (int index = 0; index < this.clusterSize; index++) {
List<StandardEntity> clusterEntities = this.clusterEntitiesList.get(index);
if (clusterEntities.isEmpty()) continue;
int sumX = 0, sumY = 0;
for (StandardEntity entity : this.clusterEntitiesList.get(index)) {
for (StandardEntity entity : clusterEntities) {
Pair<Integer, Integer> location = this.worldInfo.getLocation(entity);
sumX += location.first();
sumY += location.second();
}
int centerX = sumX / this.clusterEntitiesList.get(index).size();
int centerY = sumY / this.clusterEntitiesList.get(index).size();
int centerX = sumX / clusterEntities.size();
int centerY = sumY / clusterEntities.size();
StandardEntity center = this.getNearEntityByLine(this.worldInfo,
this.clusterEntitiesList.get(index), centerX, centerY);
clusterEntities, centerX, centerY);
if (center instanceof Area) {
this.centerList.set(index, center);
} else if (center instanceof Human) {
Expand Down Expand Up @@ -278,8 +281,6 @@ private void calcStandard(int repeat) {
this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity);
}

// this.clusterEntitiesList.sort(comparing(List::size, reverseOrder()));

if (this.assignAgentsFlag) {
List<StandardEntity> firebrigadeList = new ArrayList<>(
this.worldInfo.getEntitiesOfType(StandardEntityURN.FIRE_BRIGADE));
Expand Down Expand Up @@ -309,7 +310,7 @@ private void calcStandard(int repeat) {


private void calcPathBased(int repeat) {
this.initShortestPath(this.worldInfo);
this.initWeightedGraph(this.worldInfo);
Random random = new Random();
List<StandardEntity> entityList = new ArrayList<>(this.entities);
this.centerList = new ArrayList<>(this.clusterSize);
Expand All @@ -327,30 +328,39 @@ private void calcPathBased(int repeat) {
} while (this.centerList.contains(centerEntity));
this.centerList.set(index, centerEntity);
}

for (int i = 0; i < repeat; i++) {
this.clusterEntitiesList.clear();
for (int index = 0; index < this.clusterSize; index++) {
this.clusterEntitiesList.put(index, new ArrayList<>());
}

// Assign each entity to its nearest center using multi-source Dijkstra.
// All K centers are expanded simultaneously on the weighted road graph,
// so each entity is assigned to the center with the shortest path distance.
// Complexity: O(N log N) per iteration (formerly: O(N * K * N))
Map<EntityID, Integer> assignment = this.assignByMultiSourceDijkstra(this.centerList);
for (StandardEntity entity : entityList) {
StandardEntity tmp = this.getNearEntity(this.worldInfo, this.centerList,
entity);
this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity);
Integer clusterIndex = assignment.get(entity.getID());
if (clusterIndex != null) {
this.clusterEntitiesList.get(clusterIndex).add(entity);
}
}

// Update centers: move each center to the entity nearest to the cluster centroid
for (int index = 0; index < this.clusterSize; index++) {
List<StandardEntity> clusterEntities = this.clusterEntitiesList.get(index);
if (clusterEntities.isEmpty()) continue;
int sumX = 0, sumY = 0;
for (StandardEntity entity : this.clusterEntitiesList.get(index)) {
for (StandardEntity entity : clusterEntities) {
Pair<Integer, Integer> location = this.worldInfo.getLocation(entity);
sumX += location.first();
sumY += location.second();
}
int centerX = sumX / clusterEntitiesList.get(index).size();
int centerY = sumY / clusterEntitiesList.get(index).size();

// this.centerList.set(index, getNearEntity(this.worldInfo,
// this.clusterEntitiesList.get(index), centerX, centerY));
StandardEntity center = this.getNearEntity(this.worldInfo,
this.clusterEntitiesList.get(index), centerX, centerY);
int centerX = sumX / clusterEntities.size();
int centerY = sumY / clusterEntities.size();
StandardEntity center = this.getNearEntityByLine(this.worldInfo,
clusterEntities, centerX, centerY);
if (center instanceof Area) {
this.centerList.set(index, center);
} else if (center instanceof Human) {
Expand All @@ -361,6 +371,7 @@ private void calcPathBased(int repeat) {
this.worldInfo.getEntity(((Blockade) center).getPosition()));
}
}

if (scenarioInfo.isDebugMode()) {
System.out.print("*");
}
Expand All @@ -370,16 +381,19 @@ private void calcPathBased(int repeat) {
System.out.println();
}

// Final assignment: run multi-source Dijkstra once more with the converged centers
this.clusterEntitiesList.clear();
for (int index = 0; index < this.clusterSize; index++) {
this.clusterEntitiesList.put(index, new ArrayList<>());
}
Map<EntityID, Integer> finalAssignment = this.assignByMultiSourceDijkstra(this.centerList);
for (StandardEntity entity : entityList) {
StandardEntity tmp = this.getNearEntity(this.worldInfo, this.centerList,
entity);
this.clusterEntitiesList.get(this.centerList.indexOf(tmp)).add(entity);
Integer clusterIndex = finalAssignment.get(entity.getID());
if (clusterIndex != null) {
this.clusterEntitiesList.get(clusterIndex).add(entity);
}
}
// this.clusterEntitiesList.sort(comparing(List::size, reverseOrder()));

if (this.assignAgentsFlag) {
List<StandardEntity> fireBrigadeList = new ArrayList<>(
this.worldInfo.getEntitiesOfType(StandardEntityURN.FIRE_BRIGADE));
Expand Down Expand Up @@ -407,6 +421,99 @@ private void calcPathBased(int repeat) {
}


/**
* Voronoi partition using multi-source Dijkstra.
*
* All K centers are seeded into a priority queue at distance 0 and expanded
* simultaneously using edge-weighted shortest-path search. Each entity is
* assigned to the cluster whose center has the minimum path distance to it.
*
* Complexity: O(N log N) (formerly O(N * K * N) with comparePathDistance)
*
* @param centers list of cluster centers
* @return map from EntityID to cluster index
*/
private Map<EntityID, Integer> assignByMultiSourceDijkstra(List<StandardEntity> centers) {
Map<EntityID, Double> dist = new HashMap<>();
Map<EntityID, Integer> assignment = new HashMap<>();
// Entry: [distance, clusterIndex, entityID_value]
PriorityQueue<double[]> pq = new PriorityQueue<>(
Comparator.comparingDouble(a -> a[0]));

for (int i = 0; i < centers.size(); i++) {
EntityID cid = centers.get(i).getID();
if (!dist.containsKey(cid)) {
dist.put(cid, 0.0);
assignment.put(cid, i);
pq.offer(new double[]{0.0, i, cid.getValue()});
}
}

while (!pq.isEmpty()) {
double[] cur = pq.poll();
double d = cur[0];
int ci = (int) cur[1];
EntityID uid = new EntityID((int) cur[2]);

if (d > dist.getOrDefault(uid, Double.MAX_VALUE)) continue;
assignment.putIfAbsent(uid, ci);

Map<EntityID, Double> neighbours = weightedGraph.get(uid);
if (neighbours == null) continue;
for (Map.Entry<EntityID, Double> entry : neighbours.entrySet()) {
double nd = d + entry.getValue();
if (nd < dist.getOrDefault(entry.getKey(), Double.MAX_VALUE)) {
dist.put(entry.getKey(), nd);
assignment.put(entry.getKey(), ci);
pq.offer(new double[]{nd, ci, entry.getKey().getValue()});
}
}
}

return assignment;
}


/**
* Build an edge-weighted adjacency graph.
*
* The travel cost between adjacent areas A and B is defined as:
* cost(A -> B) = dist(A.center, midpoint of the A-B edge)
* + dist(B.center, midpoint of the B-A edge)
*
* This matches the per-edge cost used in getPathDistance(), ensuring that
* assignByMultiSourceDijkstra produces assignments equivalent to the original
* path-distance Voronoi partition.
*
* Replaces the former initShortestPath() (unweighted adjacency graph).
*/
private void initWeightedGraph(WorldInfo worldInfo) {
this.weightedGraph = new HashMap<>();
for (Entity entity : worldInfo) {
if (!(entity instanceof Area)) continue;
Area area = (Area) entity;
Pair<Integer, Integer> aCenter = worldInfo.getLocation(area);
Map<EntityID, Double> neighbours = new HashMap<>();

for (EntityID neighbourId : area.getNeighbours()) {
Entity neighbourEntity = worldInfo.getEntity(neighbourId);
if (!(neighbourEntity instanceof Area)) continue;
Area neighbour = (Area) neighbourEntity;

Edge edgeFromArea = area.getEdgeTo(neighbourId);
Edge edgeFromNeighbour = neighbour.getEdgeTo(area.getID());
if (edgeFromArea == null || edgeFromNeighbour == null) continue;

Pair<Integer, Integer> nCenter = worldInfo.getLocation(neighbour);
double weight = getDistance(aCenter, edgeFromArea)
+ getDistance(nCenter, edgeFromNeighbour);
neighbours.put(neighbourId, weight);
}
this.weightedGraph.put(area.getID(), neighbours);
}
}


private void assignAgents(WorldInfo world, List<StandardEntity> agentList) {
int clusterIndex = 0;
while (agentList.size() > 0) {
Expand Down Expand Up @@ -462,19 +569,6 @@ private StandardEntity getNearAgent(WorldInfo worldInfo,
}


private StandardEntity getNearEntity(WorldInfo worldInfo,
List<StandardEntity> srcEntityList, int targetX, int targetY) {
StandardEntity result = null;
for (StandardEntity entity : srcEntityList) {
result = (result != null)
? this.compareLineDistance(worldInfo, targetX, targetY, result,
entity)
: entity;
}
return result;
}


private Point2D getEdgePoint(Edge edge) {
Point2D start = edge.getStart();
Point2D end = edge.getEnd();
Expand Down Expand Up @@ -523,18 +617,7 @@ private StandardEntity compareLineDistance(WorldInfo worldInfo, int targetX,
}


private StandardEntity getNearEntity(WorldInfo worldInfo,
List<StandardEntity> srcEntityList, StandardEntity targetEntity) {
StandardEntity result = null;
for (StandardEntity entity : srcEntityList) {
result = (result != null)
? this.comparePathDistance(worldInfo, targetEntity, result, entity)
: entity;
}
return result;
}


// Used by getNearAgent() for initial agent placement only.
private StandardEntity comparePathDistance(WorldInfo worldInfo,
StandardEntity target, StandardEntity first, StandardEntity second) {
double firstDistance = getPathDistance(worldInfo,
Expand Down Expand Up @@ -570,36 +653,13 @@ private double getPathDistance(WorldInfo worldInfo, List<EntityID> path) {
}


private void initShortestPath(WorldInfo worldInfo) {
Map<EntityID,
Set<EntityID>> neighbours = new LazyMap<EntityID, Set<EntityID>>() {

@Override
public Set<EntityID> createValue() {
return new HashSet<>();
}
};
for (Entity next : worldInfo) {
if (next instanceof Area) {
Collection<EntityID> areaNeighbours = ((Area) next).getNeighbours();
neighbours.get(next.getID()).addAll(areaNeighbours);
}
}
for (Map.Entry<EntityID, Set<EntityID>> graph : neighbours.entrySet()) {// fix
// graph
for (EntityID entityID : graph.getValue()) {
neighbours.get(entityID).add(graph.getKey());
}
}
this.shortestPathGraph = neighbours;
}


private List<EntityID> shortestPath(EntityID start, EntityID... goals) {
return shortestPath(start, Arrays.asList(goals));
}


// BFS used by comparePathDistance() for getNearAgent().
// Reuses weightedGraph as an adjacency list (edge weights are ignored here).
private List<EntityID> shortestPath(EntityID start,
Collection<EntityID> goals) {
List<EntityID> open = new LinkedList<>();
Expand All @@ -614,7 +674,9 @@ private List<EntityID> shortestPath(EntityID start,
found = true;
break;
}
Collection<EntityID> neighbours = shortestPathGraph.get(next);
Collection<EntityID> neighbours = weightedGraph.containsKey(next)
? weightedGraph.get(next).keySet()
: Collections.emptySet();
if (neighbours.isEmpty())
continue;

Expand Down