1. Introduction

In this article, we’re going to explore a few different ways to solve the same problem – calculating the weighted mean of a set of values.

2. What Is a Weighted Mean?

We calculate the standard mean of a set of numbers by summing all of the numbers and then dividing this by the count of the numbers. For example, the mean of the numbers 1, 3, 5, 7, 9 will be (1 + 3 + 5 + 7 + 9) / 5, which equals 5.

When we’re calculating a weighted mean, we instead have a set of numbers that each have weights:

Number   Weight
1        10
3        20
5        30
7        50
9        40

In this case, we need to take the weights into account. The new calculation is to sum the product of each number with its weight and divide this by the sum of all the weights. For example, here the mean will be ((1 * 10) + (3 * 20) + (5 * 30) + (7 * 50) + (9 * 40)) / (10 + 20 + 30 + 50 + 40), which equals 6.2.
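Written as a formula, the same calculation looks like this. The symbols x_i (the numbers), w_i (their weights), and n (how many there are) are introduced here only for illustration and don't appear in the original text:

```latex
\[
\begin{aligned}
\bar{x}_w &= \frac{\sum_{i=1}^{n} w_i x_i}{\sum_{i=1}^{n} w_i} \\
          &= \frac{10 \cdot 1 + 20 \cdot 3 + 30 \cdot 5 + 50 \cdot 7 + 40 \cdot 9}{10 + 20 + 30 + 50 + 40} \\
          &= \frac{930}{150} = 6.2
\end{aligned}
\]
```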

3. Setting Up

For the sake of these examples, we’ll do some initial setup. The most important thing is that we need a type to represent our weighted values:

private static class Values {
    int value;
    int weight;

    public Values(int value, int weight) {
        this.value = value;
        this.weight = weight;
    }
}

In our sample code, we’ll also have an initial set of values and an expected result from our average:

private List<Values> values = Arrays.asList(
    new Values(1, 10),
    new Values(3, 20),
    new Values(5, 30),
    new Values(7, 50),
    new Values(9, 40)
);

private Double expected = 6.2;

4. Two-Pass Calculation

The most obvious way to calculate this is exactly as we saw above. We can iterate over the list of numbers and separately sum the values that we need for our division:

double top = values.stream()
  .mapToDouble(v -> v.value * v.weight)
  .sum();
double bottom = values.stream()
  .mapToDouble(v -> v.weight)
  .sum();

Having done this, our calculation is now just a case of dividing one by the other:

double result = top / bottom;
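For readers who want to run this, here is a minimal, self-contained sketch of the two-pass version. The wrapper class TwoPassDemo and the method name weightedMean are hypothetical names chosen for this example, and the cast to double before multiplying is an addition that avoids int overflow for large inputs:

```java
import java.util.Arrays;
import java.util.List;

// TwoPassDemo is a hypothetical wrapper so the sketch compiles as one file.
class TwoPassDemo {
    // Same shape as the Values type from the setup section.
    static class Values {
        final int value;
        final int weight;

        Values(int value, int weight) {
            this.value = value;
            this.weight = weight;
        }
    }

    static double weightedMean(List<Values> values) {
        // First pass: sum of value * weight.
        double top = values.stream()
          .mapToDouble(v -> (double) v.value * v.weight)
          .sum();
        // Second pass: sum of the weights.
        double bottom = values.stream()
          .mapToDouble(v -> v.weight)
          .sum();
        return top / bottom;
    }

    public static void main(String[] args) {
        List<Values> values = Arrays.asList(
            new Values(1, 10), new Values(3, 20), new Values(5, 30),
            new Values(7, 50), new Values(9, 40));
        System.out.println(weightedMean(values)); // 930 / 150
    }
}
```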

We can simplify this further by using a traditional for loop instead, which computes both sums in a single pass over the list. The downside here is that the running totals have to be mutable variables rather than immutable values:

double top = 0;
double bottom = 0;

for (Values v : values) {
    top += (v.value * v.weight);
    bottom += v.weight;
}
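The loop stops short of the final division, and it doesn't say what should happen for an empty list or for weights that sum to zero, where top / bottom would be 0.0 / 0.0, which is NaN. The sketch below shows one possible way to handle that by returning an OptionalDouble. That policy, the long accumulators, and the names SinglePassLoopDemo and weightedMean are assumptions made for this example:

```java
import java.util.Arrays;
import java.util.List;
import java.util.OptionalDouble;

// SinglePassLoopDemo is a hypothetical wrapper so the sketch compiles as one file.
class SinglePassLoopDemo {
    static class Values {
        final int value;
        final int weight;

        Values(int value, int weight) {
            this.value = value;
            this.weight = weight;
        }
    }

    static OptionalDouble weightedMean(List<Values> values) {
        long top = 0;
        long bottom = 0;

        for (Values v : values) {
            // Widen to long before multiplying so the product can't overflow an int.
            top += (long) v.value * v.weight;
            bottom += v.weight;
        }

        // No total weight means there is no meaningful average to report.
        if (bottom == 0) {
            return OptionalDouble.empty();
        }
        return OptionalDouble.of((double) top / bottom);
    }

    public static void main(String[] args) {
        List<Values> values = Arrays.asList(
            new Values(1, 10), new Values(3, 20), new Values(5, 30),
            new Values(7, 50), new Values(9, 40));
        System.out.println(weightedMean(values)); // wraps 930 / 150
    }
}
```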

5. Expanding the List

We can think about our weighted average calculation in a different way. Instead of calculating a sum of products, we can instead expand each of the weighted values. For example, we can expand our list to contain 10 copies of “1”, 20 copies of “3”, and so on. At this point, we can do a straight average on the expanded list:

double result = values.stream()
  .flatMap(v -> Collections.nCopies(v.weight, v.value).stream())
  .mapToInt(v -> v)
  .average()
  .getAsDouble();

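This is less efficient, because the amount of work grows with the sum of the weights rather than with the number of entries, and it only works when the weights are non-negative integers. However, it may also be clearer and easier to understand. We can also more easily do other manipulations on the final set of numbers. For example, finding the median is much easier to understand this way.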
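To illustrate the median case, here is a rough sketch that expands the list, sorts it, and picks the middle of the result. ExpandedListDemo, expand, and median are hypothetical names for this example, and rejecting an empty list with an exception is an assumption rather than something the article specifies:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// ExpandedListDemo is a hypothetical wrapper so the sketch compiles as one file.
class ExpandedListDemo {
    static class Values {
        final int value;
        final int weight;

        Values(int value, int weight) {
            this.value = value;
            this.weight = weight;
        }
    }

    // Expands each value into `weight` copies and sorts the result.
    static int[] expand(List<Values> values) {
        return values.stream()
          .flatMap(v -> Collections.nCopies(v.weight, v.value).stream())
          .mapToInt(Integer::intValue)
          .sorted()
          .toArray();
    }

    static double median(List<Values> values) {
        int[] expanded = expand(values);
        if (expanded.length == 0) {
            throw new IllegalArgumentException("no values to take a median of");
        }
        int mid = expanded.length / 2;
        // Even count: average the two middle elements. Odd count: take the middle one.
        return expanded.length % 2 == 0
          ? (expanded[mid - 1] + expanded[mid]) / 2.0
          : expanded[mid];
    }

    public static void main(String[] args) {
        List<Values> values = Arrays.asList(
            new Values(1, 10), new Values(3, 20), new Values(5, 30),
            new Values(7, 50), new Values(9, 40));
        // 150 expanded elements; both middle elements fall in the run of 7s.
        System.out.println(median(values));
    }
}
```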

6. Reducing the List

We’ve seen that summing the products and weights is more efficient than expanding out the values. But what if we want to do this in a single pass without using mutable values? We can achieve this using the reduce() functionality from Streams. In particular, we’ll use it to perform our addition as we go, collecting the running totals into an immutable object.

The first thing we want is a class to collect our running totals into:

class WeightedAverage {
    final double top;
    final double bottom;

    public WeightedAverage(double top, double bottom) {
        this.top = top;
        this.bottom = bottom;
    }

    double average() {
        return top / bottom;
    }
}

We’ve also included an average() function on this that will do our final calculation. Now, we can perform our reduction:

double result = values.stream()
  .reduce(new WeightedAverage(0, 0),
    (acc, next) -> new WeightedAverage(
      acc.top + (next.value * next.weight),
      acc.bottom + next.weight),
    (left, right) -> new WeightedAverage(
      left.top + right.top,
      left.bottom + right.bottom))
  .average();

This looks very complicated, so let’s break it down into parts.

The first parameter to reduce() is our identity. This is the weighted average with values of 0.

The next parameter is a lambda that takes a WeightedAverage instance and adds the next value to it. We’ll notice that our sum here is calculated in the same way as what we performed earlier.

The final parameter is a lambda for combining two WeightedAverage instances. This is necessary for certain cases with reduce(), such as if we were doing this on a parallel stream.

The result of the reduce() call is then a WeightedAverage instance that we can use to calculate our result.
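Putting the pieces together, the following self-contained sketch runs the same reduction on either a sequential or a parallel stream, so the combiner lambda can actually come into play. The wrapper class ReduceDemo, the weightedMean method, and its boolean parallel flag are hypothetical additions for this example:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

// ReduceDemo is a hypothetical wrapper so the sketch compiles as one file.
class ReduceDemo {
    static class Values {
        final int value;
        final int weight;

        Values(int value, int weight) {
            this.value = value;
            this.weight = weight;
        }
    }

    // Immutable running totals, as described above.
    static class WeightedAverage {
        final double top;
        final double bottom;

        WeightedAverage(double top, double bottom) {
            this.top = top;
            this.bottom = bottom;
        }

        double average() {
            return top / bottom;
        }
    }

    static double weightedMean(List<Values> values, boolean parallel) {
        Stream<Values> stream = parallel ? values.parallelStream() : values.stream();
        return stream.reduce(
            new WeightedAverage(0, 0),                      // identity
            (acc, next) -> new WeightedAverage(             // fold one value in
              acc.top + (next.value * next.weight),
              acc.bottom + next.weight),
            (left, right) -> new WeightedAverage(           // merge partial results
              left.top + right.top,
              left.bottom + right.bottom))
          .average();
    }

    public static void main(String[] args) {
        List<Values> values = Arrays.asList(
            new Values(1, 10), new Values(3, 20), new Values(5, 30),
            new Values(7, 50), new Values(9, 40));
        System.out.println(weightedMean(values, false));
        System.out.println(weightedMean(values, true));
    }
}
```

Because every step creates a new WeightedAverage rather than modifying one, the lambdas have no shared mutable state, which is what makes the parallel run safe.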

7. Custom Collectors

Our reduce() version is certainly clean, but it’s harder to understand than our other attempts. We’ve ended up with two lambdas being passed into the function, and we still need to perform a post-processing step to calculate the average.

One final solution that we can explore is writing a custom collector to encapsulate this work. This will directly produce our result, and it’ll be much simpler to use.

Before we write our collector, let’s look at the interface we need to implement:

public interface Collector<T, A, R> {
    Supplier<A> supplier();
    BiConsumer<A, T> accumulator();
    BinaryOperator<A> combiner();
    Function<A, R> finisher();
    Set<Characteristics> characteristics();
}

There’s a lot going on here, but we’ll work through it as we build our collector. We’ll also see how some of this extra complexity allows us to use the exact same collector on a parallel stream instead of only on a sequential stream.

The first thing to note is the generic types:

  • T – This is the input type. Our collector always needs to be tied to the type of values that it can collect.
  • R – This is the result type. Our collector always needs to specify the type it will produce.
  • A – This is the aggregation type. This is typically internal to the collector but is necessary for some of the function signatures.

This means that we need to define an aggregation type. This is just a type that collects a running result as we’re going. We can’t just keep this state directly in our collector because we need to be able to support parallel streams, where several partial results may be accumulated at the same time. As such, we define a separate type that stores the running totals for each part of a parallel stream:

class RunningTotals {
    double top;
    double bottom;

    public RunningTotals() {
        this.top = 0;
        this.bottom = 0;
    }
}

This is a mutable type, but because each instance is only ever used for one part of a parallel stream at a time, that’s okay.

Now, we can implement our collector methods. We’ll notice that most of these return lambdas. Again, this is to support parallel streams where the underlying streams framework will call some combination of them as appropriate.

The first method is supplier(). This constructs a new, zero instance of our RunningTotals:

@Override
public Supplier<RunningTotals> supplier() {
    return RunningTotals::new;
}

Next, we have accumulator(). This takes a RunningTotals instance and the next Values instance to process and combines them, updating our RunningTotals instance in place:

@Override
public BiConsumer<RunningTotals, Values> accumulator() {
    return (current, next) -> {
        current.top += (next.value * next.weight);
        current.bottom += next.weight;
    };
}

Our next method is combiner(). This takes two RunningTotals instances – from different parallel streams – and combines them into one:

@Override
public BinaryOperator<RunningTotals> combiner() {
    return (left, right) -> {
        left.top += right.top;
        left.bottom += right.bottom;

        return left;
    };
}

In this case, we’re mutating one of our inputs and directly returning that. This is perfectly safe, but we can also return a new instance if that’s easier.

The combiner is only used when the stream is processed in parallel and the streams framework splits the work into multiple parts, which depends on several factors. However, we should implement it in case this does ever happen.

The final lambda method that we need to implement is finisher(). This takes the final RunningTotals instance that is left after all of the values have been accumulated and all of the parallel streams have been combined, and returns the final result:

@Override
public Function<RunningTotals, Double> finisher() {
    return rt -> rt.top / rt.bottom;
}

Our Collector also needs a characteristics() method that returns a set of characteristics describing how the collector can be used. The Collector.Characteristics enum consists of three values:

  • CONCURRENT – The accumulator() function is safe to call on the same aggregation instance from parallel threads. If this is specified, then the framework can accumulate into a single shared instance instead of combining separate ones, so the accumulator() function must take extra care to be thread-safe.
  • UNORDERED – The collector can safely process the elements from the underlying stream in any order. If this isn’t specified, then, where possible, the values will be provided in the correct order.
  • IDENTITY_FINISH – The finisher() function just directly returns its input. If this is specified, then the collection process may short-circuit this call and just return the value directly.

In our case, we have an UNORDERED collector but need to omit the other two:

@Override
public Set<Characteristics> characteristics() {
    return Collections.singleton(Characteristics.UNORDERED);
}

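We’re now ready to use our collector. Here, WeightedAverage is the class that implements Collector<Values, RunningTotals, Double> with the methods above, which is a separate class from the one we used with reduce() earlier: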

double result = values.stream().collect(new WeightedAverage());
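The snippets above show each method on its own but never the class as a whole. A minimal sketch of how the pieces might fit together is shown below. The wrapper class CollectorDemo and the choice to nest everything inside it are assumptions made so that the example compiles as a single file:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;

// CollectorDemo is a hypothetical wrapper so the sketch compiles as one file.
class CollectorDemo {
    static class Values {
        final int value;
        final int weight;

        Values(int value, int weight) {
            this.value = value;
            this.weight = weight;
        }
    }

    // Mutable aggregation type; both fields start at 0.
    static class RunningTotals {
        double top;
        double bottom;
    }

    // T = Values (input), A = RunningTotals (aggregation), R = Double (result).
    static class WeightedAverage implements Collector<Values, RunningTotals, Double> {
        @Override
        public Supplier<RunningTotals> supplier() {
            return RunningTotals::new;
        }

        @Override
        public BiConsumer<RunningTotals, Values> accumulator() {
            return (current, next) -> {
                current.top += (next.value * next.weight);
                current.bottom += next.weight;
            };
        }

        @Override
        public BinaryOperator<RunningTotals> combiner() {
            return (left, right) -> {
                left.top += right.top;
                left.bottom += right.bottom;
                return left;
            };
        }

        @Override
        public Function<RunningTotals, Double> finisher() {
            return rt -> rt.top / rt.bottom;
        }

        @Override
        public Set<Characteristics> characteristics() {
            return Collections.singleton(Characteristics.UNORDERED);
        }
    }

    public static void main(String[] args) {
        List<Values> values = Arrays.asList(
            new Values(1, 10), new Values(3, 20), new Values(5, 30),
            new Values(7, 50), new Values(9, 40));

        // The same collector class works for both kinds of stream.
        double sequential = values.stream().collect(new WeightedAverage());
        double parallel = values.parallelStream().collect(new WeightedAverage());
        System.out.println(sequential + " " + parallel);
    }
}
```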

While writing the collector is much more complicated than before, using it is significantly easier. We can also leverage things like parallel streams with no extra work, meaning that this gives us an easier-to-use and more powerful solution, assuming that we need to reuse it.
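If a dedicated class feels heavy, a collector with the same behavior can also be assembled with the static Collector.of() factory method. The sketch below is one way to do it, not the article's approach: it uses a two-element double[] holding the top and bottom totals as its aggregation type, and CollectorOfDemo and weightedAverage are hypothetical names:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collector;

// CollectorOfDemo is a hypothetical wrapper so the sketch compiles as one file.
class CollectorOfDemo {
    static class Values {
        final int value;
        final int weight;

        Values(int value, int weight) {
            this.value = value;
            this.weight = weight;
        }
    }

    // totals[0] is the sum of value * weight, totals[1] is the sum of the weights.
    static Collector<Values, double[], Double> weightedAverage() {
        return Collector.of(
            () -> new double[2],                                  // supplier
            (totals, next) -> {                                   // accumulator
                totals[0] += (double) next.value * next.weight;
                totals[1] += next.weight;
            },
            (left, right) -> {                                    // combiner
                left[0] += right[0];
                left[1] += right[1];
                return left;
            },
            totals -> totals[0] / totals[1],                      // finisher
            Collector.Characteristics.UNORDERED);
    }

    public static void main(String[] args) {
        List<Values> values = Arrays.asList(
            new Values(1, 10), new Values(3, 20), new Values(5, 30),
            new Values(7, 50), new Values(9, 40));
        System.out.println(values.stream().collect(weightedAverage()));
    }
}
```

The trade-off is readability: the array indices are less self-describing than named fields, so this form probably suits small, local collectors best.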

8. Conclusion

Here, we’ve seen several different ways that we can calculate a weighted average of a set of values, ranging from simply looping over the values ourselves to writing a full Collector instance that can be reused whenever we need to perform this calculation. Next time you need to do this, why not give one of these a go?

As always, the full code for this article is available over on GitHub.
