Sum.java
package org.djutils.math.functions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.SortedSet;
import java.util.TreeSet;
import org.djutils.exceptions.Throw;
/**
* Add up one or more MathFunction objects.
* <p>
* Copyright (c) 2024-2025 Delft University of Technology, Jaffalaan 5, 2628 BX Delft, the Netherlands. All rights reserved. See
* for project information <a href="https://djutils.org" target="_blank"> https://djutils.org</a>. The DJUTILS project is
* distributed under a three-clause BSD-style license, which can be found at
* <a href="https://djutils.org/docs/license.html" target="_blank"> https://djutils.org/docs/license.html</a>.
* </p>
* @author <a href="https://github.com/averbraeck">Alexander Verbraeck</a>
* @author <a href="https://github.com/peter-knoppers">Peter Knoppers</a>
* @author <a href="https://github.com/wjschakel">Wouter Schakel</a>
*/
public class Sum implements MathFunction
{
/** The functions whose values will be summed. */
private final List<MathFunction> terms;
/**
* Construct the sum of one or more functions.
* @param functions the functions that this Sum will add together.
* @throws IllegalArgumentException when zero parameters are provided
* @throws NullPointerException when a <code>null</code> value is among the arguments
*/
public Sum(final MathFunction... functions)
{
this(Arrays.asList(functions));
}
/**
* Construct the sum of one or more functions.
* @param functions the functions that this Sum will add together.
* @throws IllegalArgumentException when zero parameters are provided
* @throws NullPointerException when a <code>null</code> value is among the arguments
*/
public Sum(final List<MathFunction> functions)
{
Throw.when(functions.size() == 0, IllegalArgumentException.class, "Sum needs at least one object to sum");
this.terms = simplify(functions);
}
/**
* Simplify a set of terms that must be added together.
* @param functions the terms that must be added together
* @return minimal array with the remaining terms
*/
private List<MathFunction> simplify(final List<MathFunction> functions)
{
List<MathFunction> result = new ArrayList<>(functions);
// Pull up all Sums that are embedded in this Sum
for (int index = 0; index < result.size(); index++)
{
MathFunction function = result.get(index);
Throw.whenNull(function, "function");
if (function instanceof Sum)
{
// Replace any embedded Sum by all terms that comprise that Sum
result.remove(index);
index--;
result.addAll(((Sum) function).terms);
}
}
// Optimize all elements
for (int index = 0; index < result.size(); index++)
{
MathFunction function = result.get(index);
MathFunction optimized = function.simplify();
if (!function.equals(optimized))
{
result.remove(index);
result.add(index, optimized);
}
}
Collections.sort(result);
// Merge all functions that can be merged
for (int index = 0; index < result.size(); index++)
{
MathFunction function = result.get(index);
if (function.equals(Constant.ZERO))
{
result.remove(index);
index--;
}
else if (index < result.size() - 1)
{
MathFunction nextFunction = result.get(index + 1);
MathFunction merged = function.mergeAdd(nextFunction);
if (merged != null)
{
result.remove(index);
result.remove(index);
result.add(index, merged);
index--; // try to merge it with yet one more MathFunction
}
}
}
if (result.size() == 0)
{
result.add(Constant.ZERO);
}
return result;
}
@Override
public Double apply(final Double x)
{
double result = 0.0;
for (MathFunction fi : this.terms)
{
result += fi.apply(x);
}
return result;
}
@Override
public MathFunction getDerivative()
{
List<MathFunction> derivatives = new ArrayList<>();
for (MathFunction term : this.terms)
{
derivatives.add(term.getDerivative());
}
return new Sum(derivatives).simplify();
}
@Override
public MathFunction simplify()
{
List<MathFunction> simplifiedTerms = simplify(this.terms);
if (simplifiedTerms.size() == 1)
{
return simplifiedTerms.get(0);
}
return this;
}
@Override
public MathFunction scaleBy(final double factor)
{
if (factor == 0.0)
{
return Constant.ZERO;
}
if (factor == 1.0)
{
return this;
}
List<MathFunction> result = new ArrayList<>(this.terms.size());
for (MathFunction function : this.terms)
{
result.add(function.scaleBy(factor));
}
return new Sum(result);
}
@Override
public int sortPriority()
{
return 101;
}
@Override
public int compareWithinSubType(final MathFunction other)
{
Throw.when(!(other instanceof Sum), IllegalArgumentException.class, "other is of wrong type");
Sum otherSum = (Sum) other;
for (int index = 0; index < this.terms.size(); index++)
{
if (index >= otherSum.terms.size())
{
return 1;
}
int result = this.terms.get(index).compareTo(otherSum.terms.get(index));
if (result != 0)
{
return result;
}
}
if (otherSum.terms.size() > this.terms.size())
{
return -1;
}
return 0;
}
@Override
public KnotReport getKnotReport(final Interval<?> interval)
{
KnotReport result = KnotReport.NONE;
for (MathFunction term : this.terms)
{
result = result.combineWith(term.getKnotReport(interval));
}
return result;
}
@Override
public SortedSet<Double> getKnots(final Interval<?> interval)
{
SortedSet<Double> result = new TreeSet<>();
for (MathFunction term : this.terms)
{
result.addAll(term.getKnots(interval));
}
return result;
}
@Override
public String toString()
{
StringBuilder result = new StringBuilder();
result.append("\u03A3("); // Capital sigma (Σ)
for (int i = 0; i < this.terms.size(); i++)
{
if (i > 0)
{
result.append(", ");
}
result.append(this.terms.get(i).toString());
}
result.append(")");
return result.toString();
}
@Override
public int hashCode()
{
return Objects.hash(this.terms);
}
@SuppressWarnings("checkstyle:needbraces")
@Override
public boolean equals(final Object obj)
{
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Sum other = (Sum) obj;
return Objects.equals(this.terms, other.terms);
}
}