Skip to content

Commit

Permalink
update speed models, gradient based optimization instead of optuna
Browse files Browse the repository at this point in the history
  • Loading branch information
rakow committed Aug 11, 2023
1 parent 88afd9b commit a2ee4af
Show file tree
Hide file tree
Showing 10 changed files with 5,619 additions and 2,458 deletions.
19 changes: 19 additions & 0 deletions src/main/java/org/matsim/prepare/network/FeatureRegressor.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@
*/
public interface FeatureRegressor {


/**
* Predict value from given features.
*/
double predict(Object2DoubleMap<String> ft);

/**
* Predict values with adjusted model params.
*/
default double predict(Object2DoubleMap<String> ft, double[] params) {
throw new UnsupportedOperationException("Not implemented");
}


/**
* Return data that is used for internal prediction function (normalization already applied).
*/
default double[] getData(Object2DoubleMap<String> ft) {
throw new UnsupportedOperationException("Not implemented");
}

}
117 changes: 83 additions & 34 deletions src/main/java/org/matsim/prepare/network/FreeSpeedOptimizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleDoublePair;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.Int2ObjectLinkedOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
Expand Down Expand Up @@ -45,17 +44,16 @@
import java.io.PrintWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.DoubleStream;

@CommandLine.Command(
name = "network-freespeed",
description = "Start server for freespeed optimization."
)
@CommandSpec(
requireNetwork = true
requireNetwork = true,
requires = "features.csv"
)
public class FreeSpeedOptimizer implements MATSimAppCommand {

Expand All @@ -75,6 +73,7 @@ public class FreeSpeedOptimizer implements MATSimAppCommand {

private Network network;
private Object2DoubleMap<Entry> validationSet;
private Map<Id<Link>, PrepareNetworkParams.Feature> features;

private ObjectMapper mapper;

Expand All @@ -99,7 +98,8 @@ public Integer call() throws Exception {
speeds.put(link.getId(), link.getFreespeed());
}

validationSet = readValidation();
validationSet = readValidation(validationFiles);
features = PrepareNetworkParams.readFeatures(input.getPath("features.csv"), network.getLinks().size());

log.info("Initial score:");
evaluateNetwork(null, "init");
Expand Down Expand Up @@ -133,29 +133,45 @@ public Integer call() throws Exception {
return 0;
}

private DoubleDoublePair evaluateNetwork(Request request, String save) throws IOException {
private Result evaluateNetwork(Request request, String save) throws IOException {

Map<Id<Link>, double[]> attributes = new HashMap<>();

if (request != null) {
for (Link link : network.getLinks().values()) {

double allowedSpeed = NetworkUtils.getAllowedSpeed(link);
double speed = speeds.getDouble(link.getId());

if (request.f == 0) {
double speedFactor = (double) link.getAttributes().getAttribute("speed_factor");

if (allowedSpeed <= 31 / 3.6) {
link.setFreespeed(speed * request.b30);
link.getAttributes().putAttribute("speed_factor", speedFactor * request.b30);

} else if (allowedSpeed <= 51 / 3.6) {
link.setFreespeed(speed * request.b50);
link.getAttributes().putAttribute("speed_factor", speedFactor * request.b50);
} else if (allowedSpeed <= 91 / 3.6) {
link.setFreespeed(speed * request.b90);
link.getAttributes().putAttribute("speed_factor", speedFactor * request.b90);

PrepareNetworkParams.Feature ft = features.get(link.getId());
String type = NetworkUtils.getHighwayType(link);

if (type.startsWith("motorway")) {
link.setFreespeed(allowedSpeed);
continue;
}

FeatureRegressor speedModel = switch (ft.junctionType()) {
case "traffic_light" -> Speedrelative_traffic_light.INSTANCE;
case "right_before_left" -> Speedrelative_right_before_left.INSTANCE;
case "priority" -> Speedrelative_priority.INSTANCE;
default -> throw new IllegalArgumentException("Unknown type: " + ft.junctionType());
};

double[] p = switch (ft.junctionType()) {
case "traffic_light" -> request.traffic_light;
case "right_before_left" -> request.rbl;
case "priority" -> request.priority;
default -> throw new IllegalArgumentException("Unknown type: " + ft.junctionType());
};

double speedFactor = Math.max(0.25, speedModel.predict(ft.features(), p));

attributes.put(link.getId(), speedModel.getData(ft.features()));

link.setFreespeed((double) link.getAttributes().getAttribute("allowed_speed") * speedFactor);
link.getAttributes().putAttribute("speed_factor", speedFactor);

} else
// Old MATSim freespeed logic
Expand All @@ -178,22 +194,49 @@ private DoubleDoublePair evaluateNetwork(Request request, String save) throws IO
if (csv != null)
csv.printRecord("from_node", "to_node", "beeline_dist", "dist", "travel_time");

List<Data> priority = new ArrayList<>();
List<Data> rbl = new ArrayList<>();
List<Data> traffic_light = new ArrayList<>();

for (Object2DoubleMap.Entry<Entry> e : validationSet.object2DoubleEntrySet()) {

Entry r = e.getKey();

Node fromNode = network.getNodes().get(r.fromNode);
Node toNode = network.getNodes().get(r.toNode);
Node fromNode = network.getNodes().get(r.fromNode());
Node toNode = network.getNodes().get(r.toNode());
LeastCostPathCalculator.Path path = router.calcLeastCostPath(fromNode, toNode, 0, null, null);

// iterate over the path, calc better correction
double distance = path.links.stream().mapToDouble(Link::getLength).sum();
double speed = distance / path.travelTime;

double correction = speed / e.getDoubleValue();

for (Link link : path.links) {

if (!attributes.containsKey(link.getId()))
continue;

PrepareNetworkParams.Feature ft = features.get(link.getId());
double[] input = attributes.get(link.getId());
double speedFactor = (double) link.getAttributes().getAttribute("speed_factor");

List<Data> category = switch (ft.junctionType()) {
case "traffic_light" -> traffic_light;
case "right_before_left" -> rbl;
case "priority" -> priority;
default -> throw new IllegalArgumentException("not happening");
};

category.add(new Data(input, speedFactor, speedFactor / correction));
}


rmse.addValue(Math.pow(e.getDoubleValue() - speed, 2));
mse.addValue(Math.abs((e.getDoubleValue() - speed) * 3.6));

if (csv != null)
csv.printRecord(r.fromNode, r.toNode, (int) CoordUtils.calcEuclideanDistance(fromNode.getCoord(), toNode.getCoord()),
csv.printRecord(r.fromNode(), r.toNode(), (int) CoordUtils.calcEuclideanDistance(fromNode.getCoord(), toNode.getCoord()),
(int) distance, (int) path.travelTime);
}

Expand All @@ -202,13 +245,13 @@ private DoubleDoublePair evaluateNetwork(Request request, String save) throws IO

log.info("{}, rmse: {}, mae: {}", request, rmse.getMean(), mse.getMean());

return DoubleDoublePair.of(rmse.getMean(), mse.getMean());
return new Result(rmse.getMean(), mse.getMean(), priority, rbl, traffic_light);
}

/**
* Collect highest observed speed.
*/
private Object2DoubleMap<Entry> readValidation() throws IOException {
static Object2DoubleMap<Entry> readValidation(List<String> validationFiles) throws IOException {

// entry to hour and list of speeds
Map<Entry, Int2ObjectMap<DoubleList>> entries = new LinkedHashMap<>();
Expand Down Expand Up @@ -272,14 +315,21 @@ private Object2DoubleMap<Entry> readValidation() throws IOException {
private record Entry(Id<Node> fromNode, Id<Node> toNode) {
}

private record Data(double[] x, double yPred, double yTrue) {

}

private record Result(double rmse, double mse, List<Data> priority, List<Data> rbl, List<Data> traffic_light) {}


/**
* JSON request containing desired parameters.
*/
private static final class Request {

double b30;
double b50;
double b90;
double[] priority;
double[] rbl;
double[] traffic_light;

double f;

Expand All @@ -294,9 +344,9 @@ public Request(double f) {
public String toString() {
if (f == 0)
return "Request{" +
"b30=" + b30 +
", b50=" + b50 +
", b90=" + b90 +
"priority=" + priority.length +
", rbl=" + rbl.length +
", traffic_light=" + traffic_light.length +
'}';

return "Request{f=" + f + "}";
Expand All @@ -311,7 +361,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I
Request request = mapper.readValue(req.getInputStream(), Request.class);

boolean save = req.getRequestURI().equals("/save");
DoubleDoublePair stats = evaluateNetwork(request, save ? "network-opt" : null);
Result stats = evaluateNetwork(request, save ? "network-opt" : null);

if (save)
NetworkUtils.writeNetwork(network, "network-opt.xml.gz");
Expand All @@ -320,8 +370,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws I

PrintWriter writer = resp.getWriter();

// target value
writer.println(stats.rightDouble());
mapper.writeValue(writer, stats);

writer.close();
}
Expand Down
4,245 changes: 3,360 additions & 885 deletions src/main/java/org/matsim/prepare/network/Speedrelative_priority.java

Large diffs are not rendered by default.

Loading

0 comments on commit a2ee4af

Please sign in to comment.