Java Timing Code: Compare Execution Times of Methods

Oftentimes you will have two Java functions and want to know which one performs better. You can use this Java class to time multiple methods and find out which one is faster.

Example Usage: Java Timing Code

Let’s say you have these two functions and want to know which one is faster:

    /**
     * Returns the given string followed by itself, built with the {@code +} operator.
     *
     * @param i the string to duplicate
     * @return {@code i} concatenated with {@code i}
     */
    public static String function1(String i) {
        return i + i;
    }

    /**
     * Returns the given string followed by itself, built with an explicit
     * {@link StringBuilder}.
     *
     * @param i the string to duplicate
     * @return the content of a builder to which {@code i} was appended twice
     */
    public static String function2(String i) {
        StringBuilder sb = new StringBuilder();
        sb.append(i);
        sb.append(i);
        return sb.toString();
    }

Here is how you would use my timing class:

    public static void simpleTimingTest() throws Exception {
        Timing t = new Timing();

        /*
        The timing function will run the tests in chunks.
        For each chunk, the same input will be used.

        Input will be gathered via the passed IntFunction. The timing method will
        pass the current index (going from 0 to amountRunsPerChunk) to it and will expect
        any return of the defined type.
        */

        IntFunction<String> inputProvider = i -> String.valueOf(i);

        /*
        The add method expects two functions: the above mentioned input provider,
        as well as a function which accepts the output from the input provider as input
        and applies it to the function which will be timed.
        */

        t.add((String s) -> function1(s), inputProvider, "function1 ");
        t.add((String s) -> function2(s), inputProvider, "function2 ");

        t.time(true); // true: force test (otherwise, time might throw an exception
                      // if it suspects that there isn't enough memory)
        t.output(s -> System.out.println(s), Timing.Sort.ASC);
    }

You can also do more complex things with it:

    public static void predefinedInputTimingTest() throws Exception {
        Timing t = new Timing();
        /*
        The input doesn't have to be generated using the passed index, you could
        also use predefined input to time your functions:
        */

        String[] input = new String[]{"input1", "another input", "more input"};
        IntFunction<String> inputProvider = i -> input[i % input.length];
        t.add((String s) -> function1(s), inputProvider, "function1 ");
        t.add((String s) -> function2(s), inputProvider, "function2 ");

        /*
        You can decide what should be reported when timing finished:
        */

        t.setReport(EnumSet.of(Timing.Report.NAME, Timing.Report.MEAN));
        t.setAmountChunks(1_500);
        t.setAmountRunsPerChunk(2_500);
        t.time(true, s -> System.out.println(s)); // pass String -> String function to report debug information
        t.output(s -> System.out.println(s), Timing.Sort.ASC);
    }

Example Output of Timing Result

The output of the example code above would look something like this:

name           per call (mean, ns)     per call (median, ns)     95th percentile (ns)     total (ms)     runs    
function1      82                      73                        65                       0.0656         1000000  
function2      110                     99                        84                       0.088          1000000

Source Code: Timing Java Functions

The source code uses some Java 8 features, but it should be very easy to adapt in case you do not have Java 8.

The FormatedTableBuilder class to format output can be found here.

import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import output.FormatedTableBuilder;

/**
 * A timing class.
 *
 * Timings are run in chunks. For each chunk, the input is gathered before-hand.
 */

public class Timing {

    public static enum Sort {
        ASC, DESC, NAME
    }

    public static enum Report {
        NAME, MEAN, MEDIAN, PERCENTILE, TOTAL, RUNS;
        public static final EnumSet<Report> ALL = EnumSet.allOf(Report.class);
    }

    private List<TimingObject> functionsToTime = new ArrayList<>();

    /**
     * amount of chunks to run.
     */

    private int amountChunks = 1_000;

    /**
     * amount of runs per chunk.
     */

    private int amountRunsPerChunk = 1_000;

    private EnumSet<Report> report = Report.ALL;

    /**
     * adds a new function which will be timed.
     *
     * @param <R> return type of functionToTime (irrelevant)
     * @param <T> input type of functionToTime (same as return type of
     * inputConverter)
     * @param functionToTime a function expecting input of type T, returning
     * output of any type (R)
     * @param inputConverter converts the loop variable to type T and passes it
     * to functionToTime
     * @param name name of the function (used for output)
     */

    public <R, T> void add(Function<T, R> functionToTime, IntFunction<T> inputConverter, String name) {
        Objects.requireNonNull(inputConverter);
        Objects.requireNonNull(functionToTime);
        functionsToTime.add(new TimingObject(functionToTime, inputConverter, name));
    }

    /**
     * sets how many chunks should be run.
     *
     * The total amount of how often the given functions should be run when
     * timed is amountChunks * amountRunsPerChunk.
     *
     * @param amountChunks amountChunks
     */

    public void setAmountChunks(int amountChunks) {
        if (amountChunks < 1) {
            throw new IllegalArgumentException("chunks must be at least 1.");
        }
        this.amountChunks = amountChunks;
    }

    /**
     * sets how often the function is run per chunk.
     *
     * The total amount of how often the given functions should be run when
     * timed is amountChunks * amountRunsPerChunk.
     *
     * @param amountRunsPerChunk amountRunsPerChunk
     */

    public void setAmountRunsPerChunk(int amountRunsPerChunk) {
        if (amountRunsPerChunk < 1) {
            throw new IllegalArgumentException("amountRunsPerChunk must be at least 1.");
        }
        this.amountRunsPerChunk = amountRunsPerChunk;
    }

    /**
     * sets what should be reported at the end of the test.
     *
     * @param report report
     */

    public void setReport(EnumSet<Report> report) {
        Objects.requireNonNull(report);
        this.report = report;
    }


    /**
     * performs the actual timing for all given functions.
     *
     * @param force if set to true, needed memory estimations will be ignored
     * and the tests will be run.
     * @throws Exception if force is set to false, and the estimated needed memory is too large
     */

    public void time(boolean force) throws Exception {
        time(force, null);
    }

    /**
     * performs the actual timing for all given functions.
     *
     * @param force if set to true, needed memory estimations will be ignored
     * and the tests will be run.
     * @param consumer for progress report. may be null
     * @throws Exception if force is set to false, and the estimated needed memory is too large
     */

    public void time(boolean force, Consumer<String> consumer) throws Exception {

        long memoryNeeded = amountChunks * amountRunsPerChunk * functionsToTime.size() * Long.BYTES;
        memoryNeeded *= 4; // lets be pessimistic (estimation is always off by a factor of 3-5; less off the higher the value is)
        if (!force
                && ((memoryNeeded > Integer.MAX_VALUE || memoryNeeded < Integer.MIN_VALUE)
                || memoryNeeded > Runtime.getRuntime().maxMemory())) {
            throw new Exception("Estimated memory consumption of: " + (memoryNeeded / 1048576) + " mb");
        }
        print(consumer, "starting timing. estimated memory consumption: " + (memoryNeeded / 1048576) + " mb");

        for (int chunks = 0; chunks < amountChunks; chunks++) {
            // run a chunk of tests on this timingObject:
            for (TimingObject timingObject : functionsToTime) {
                // generate input:
                ArrayList<Object> input = new ArrayList<>();
                for (int runs = 0; runs < amountRunsPerChunk; runs++) {
                    input.add(timingObject.inputConverter.apply((chunks * amountRunsPerChunk) + runs));
                }
                // run with input:
                long[] times = timeRuns(timingObject, input);
                timingObject.addTimeChunk(times);
            }
            Collections.shuffle(functionsToTime); // randomize functions each time
        }

        long usedMemory = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
        print(consumer, "done timing. currently used memory: " + (usedMemory / 1048576) + " mb");
        print(consumer, "starting evaluation");

        for (TimingObject timingObject : functionsToTime) {
            timingObject.processTimes();
        }

        print(consumer, "finished evaluation");
    }

    private void print(Consumer<String> consumer, String string) {
        if (consumer != null) {
            consumer.accept(string);
        }
    }

    /**
     * runs a chunk of functions, timing each one.
     *
     * @param <T> input type
     * @param timingObject timingObject
     * @param input list of input for functions
     * @return array of times
     */

    private <T> long[] timeRuns(TimingObject timingObject, ArrayList<T> input) {
        long[] times = new long[input.size()];
        for (int i = 0; i < input.size(); i++) {
            long start = System.nanoTime();
            timingObject.function.apply(input.get(i));
            times[i] = System.nanoTime() - start;
        }
        return times;
    }

    /**
     * passes the result of the timing to the given consumer.
     *
     * @param consumer consumer
     * @param sort how to sort the result
     */

    public void output(Consumer<String> consumer, Sort sort) {
        Objects.requireNonNull(consumer);
        Objects.requireNonNull(sort);

        switch (sort) {
            case ASC:
                Collections.sort(functionsToTime,
                        (TimingObject t1, TimingObject t2) -> (int) (t1.meanTime - t2.meanTime));
                break;
            case DESC:
                Collections.sort(functionsToTime,
                        (TimingObject t1, TimingObject t2) -> (int) (t2.meanTime - t1.meanTime));
                break;
            case NAME:
                Collections.sort(functionsToTime,
                        (TimingObject t1, TimingObject t2) -> t1.name.compareTo(t2.name));
                break;
            default:
                break;
        }

        FormatedTableBuilder formater = new FormatedTableBuilder();

        formater.addLine(new String[]{"name", "per call (mean, ns)", "per call (median, ns)", "95th percentile (ns)", "total (ms)", "runs"});

        formater.add(functionsToTime, (TimingObject timing) -> new String[]{
            timing.name,
            String.valueOf(timing.getMeanTime()),
            String.valueOf(timing.getMedianTime()),
            String.valueOf(timing.getPercentile(95)),
            String.valueOf(timing.getMeanTime() * timing.times.size() / 1000000000.0),
            String.valueOf(amountChunks * amountRunsPerChunk),
        });

        if (!report.contains(Report.NAME)) {
            formater.removeColumn("name", 0);
        }
        if (!report.contains(Report.MEAN)) {
            formater.removeColumn("per call (mean, ns)", 0);
        }
        if (!report.contains(Report.MEDIAN)) {
            formater.removeColumn("per call (median, ns)", 0);
        }
        if (!report.contains(Report.PERCENTILE)) {
            formater.removeColumn("95th percentile (ns)", 0);
        }
        if (!report.contains(Report.TOTAL)) {
            formater.removeColumn("total (ms)", 0);
        }
        if (!report.contains(Report.RUNS)) {
            formater.removeColumn("runs", 0);
        }
        consumer.accept(formater.format());
    }

    private class TimingObject {

        private Function function;
        private IntFunction inputConverter;
        private String name;
        private long meanTime;
        private List<Long> times;

        public TimingObject(Function function, IntFunction inputConverter, String name) {
            this.function = function;
            this.inputConverter = inputConverter;
            this.name = name;
            this.times = new ArrayList<>();
        }

        public void addTimeChunk(long[] timeChunk) {
            for (int i = 0; i < timeChunk.length; i++) {
                times.add(timeChunk[i]);
            }
        }

        public void processTimes() {
            Statistics.removeWorstAndBest(times); // also sorts
            meanTime = (long) Statistics.calculateMean(times);
        }

        public long getMeanTime() {
            return meanTime;
        }

        public long getMedianTime() {
            return (long) Statistics.calculateMedian(times);
        }

        public long getPercentile(int percentile) {
            return (long) Statistics.percentile(times, percentile);
        }

    }
}
import java.util.Collections;
import java.util.List;

public class Statistics {

    /** percentage of the lowest values (the best times) which is discarded. */
    private static final int REMOVE_BEST_PERCENT = 10;
    /** percentage of the highest values (the worst times) which is discarded. */
    private static final int REMOVE_WORST_PERCENT = 10;

    /**
     * removes the x lowest and x highest values from the list.
     *
     * Also sorts the list ascending.
     *
     * @param list list; it is modified in place
     */

    public static void removeWorstAndBest(List<Long> list) {
        // sort ascending (natural order; a casted subtraction would break for
        // differences which do not fit into an int)
        Collections.sort(list);

        // remove x worst and x best results; the counts are calculated with
        // long values so that large lists do not overflow
        int originalSize = list.size();
        int worstCount = (int) ((long) REMOVE_WORST_PERCENT * originalSize / 100);
        int bestCount = (int) ((long) REMOVE_BEST_PERCENT * originalSize / 100);
        list.subList(originalSize - worstCount, originalSize).clear();
        list.subList(0, bestCount).clear();
    }

    /**
     * returns the mean of the list.
     *
     * @param list list
     * @return mean, or 0 if the list is empty
     */

    public static double calculateMean(final List<Long> list) {
        double mean = 0;
        final int length = list.size();
        for (long value : list) {
            // divide each value instead of the sum, so that the sum cannot overflow
            mean += value / (double) length;
        }
        return mean;
    }

    /**
     * returns the median of the list.
     *
     * For a list with an even amount of values, this is the mean of the two
     * middle values.
     *
     * Expects a sorted list.
     *
     * @param list list
     * @return median
     * @throws IndexOutOfBoundsException if the list is empty
     */

    public static double calculateMedian(List<Long> list) {
        int size = list.size();
        int middle = size / 2;
        if (size == 0 || size % 2 == 1) {
            return list.get(middle);
        }
        // halve before adding, so that the sum of the two values cannot overflow
        return list.get(middle - 1) / 2.0 + list.get(middle) / 2.0;
    }

    /**
     * returns the percentile (nearest-rank method).
     *
     * The result is the smallest value of the list for which at least
     * percentile percent of all values are smaller or equal. A percentile of 0
     * returns the smallest value.
     *
     * Expects a sorted list.
     *
     * @param list list
     * @param percentile percentile (0 to 100)
     * @return percentile
     * @throws IllegalArgumentException if percentile is not between 0 and 100
     * @throws IndexOutOfBoundsException if the list is empty
     */

    public static double percentile(List<Long> list, int percentile) {
        if (percentile < 0 || percentile > 100) {
            throw new IllegalArgumentException("percentile must be between 0 and 100, but was " + percentile);
        }
        // ceil(percentile / 100 * size) in integer arithmetic, converted to a zero-based index
        long rank = ((long) percentile * list.size() + 99) / 100 - 1;
        return list.get((int) Math.max(0, rank));
    }
}

Thanks and License

I posted this code for review at CodeReview.SE here as well as a follow-up here and want to thank the reviewers there.

Because of this, the code is licensed under CC BY-SA.

Leave a Reply

Your email address will not be published.