Concatenation.java
package org.djutils.math.functions;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import org.djutils.exceptions.Throw;
/**
* Concatenation of MathFunction objects, each of which is active on its own (non-overlapping) interval of the domain.
* <p>
* Copyright (c) 2024-2025 Delft University of Technology, Jaffalaan 5, 2628 BX Delft, the Netherlands. All rights reserved. See
* <a href="https://djutils.org" target="_blank"> https://djutils.org</a> for project information. The DJUTILS project is
* distributed under a three-clause BSD-style license, which can be found at
* <a href="https://djutils.org/docs/license.html" target="_blank"> https://djutils.org/docs/license.html</a>.
* </p>
* @author <a href="https://github.com/averbraeck">Alexander Verbraeck</a>
* @author <a href="https://github.com/peter-knoppers">Peter Knoppers</a>
* @author <a href="https://github.com/wjschakel">Wouter Schakel</a>
*/
public class Concatenation implements MathFunction
{
    /** The wrapped functions, in sorted order; gaps between them are filled with NaN functions by the constructor. */
    private final SortedSet<Interval<MathFunction>> functions;

    /**
     * Construct the concatenation of one or more MathFunction objects.
     * @param intervals the functions and the domains over which they should be active
     * @throws IllegalArgumentException when no intervals are given, or when the domains overlap
     */
    @SafeVarargs
    public Concatenation(final Interval<MathFunction>... intervals)
    {
        this(convertToSortedSet(intervals));
    }

    /**
     * Convert an array of {@code Interval<MathFunction>} to a SortedSet.
     * @param intervals the intervals
     * @return sorted set (natural ordering of Interval) containing all the intervals
     */
    @SafeVarargs
    private static SortedSet<Interval<MathFunction>> convertToSortedSet(final Interval<MathFunction>... intervals)
    {
        SortedSet<Interval<MathFunction>> result = new TreeSet<>();
        for (var interval : intervals)
        {
            result.add(interval);
        }
        return result;
    }

    /**
     * Construct a Concatenation from a sorted set of {@code Interval<MathFunction>}. Any gap between consecutive intervals is
     * filled with an interval that has the <code>Nan.NAN</code> function as payload. The provided set is not modified.
     * @param set the sorted set of Interval with MathFunction payload
     * @throws IllegalArgumentException when the set is empty, or when the domains of the intervals overlap
     */
    public Concatenation(final SortedSet<Interval<MathFunction>> set)
    {
        Throw.when(set.size() < 1, IllegalArgumentException.class, "need at least one argument");
        // Work on a copy: adding gap fillers to the set that is being iterated would throw a ConcurrentModificationException
        // (TreeSet iterators are fail-fast), and the caller's set should not be altered.
        SortedSet<Interval<MathFunction>> result = new TreeSet<>(set);
        // Run the ordered list and check for overlaps and add NaN functions where there are gaps
        Interval<MathFunction> prevInterval = null;
        for (Interval<MathFunction> thisInterval : set)
        {
            if (prevInterval != null)
            {
                Throw.when(!prevInterval.disjunct(thisInterval), IllegalArgumentException.class,
                        "Overlapping domains not permitted");
                // A gap exists when the intervals are separated, or when they touch at a point that neither includes
                if (prevInterval.high() < thisInterval.low() || (prevInterval.high() == thisInterval.low()
                        && (!prevInterval.highInclusive()) && (!thisInterval.lowInclusive())))
                {
                    // There is a gap; fill it with a NaN function that covers exactly the points the neighbors do not
                    result.add(new Interval<MathFunction>(prevInterval.high(), !prevInterval.highInclusive(),
                            thisInterval.low(), !thisInterval.lowInclusive(), Nan.NAN));
                }
            }
            prevInterval = thisInterval;
        }
        this.functions = result;
    }

    /**
     * Evaluate the function whose interval covers <code>x</code>.
     * @param x the domain value
     * @return the value of the function that is active at <code>x</code>
     * @throws IllegalArgumentException when <code>x</code> is not covered by any of the intervals
     */
    @Override
    public Double apply(final Double x)
    {
        // TODO Use bisection to home in on the interval that covers x; for now use linear search
        for (var interval : this.functions)
        {
            if (interval.covers(x))
            {
                return interval.payload().apply(x);
            }
        }
        // Include the offending value in the message to ease debugging
        throw new IllegalArgumentException(
                String.format("x (%s) is outside the combined domain of this Concatenation", x));
    }

    /**
     * Construct the derivative: a Concatenation over the same intervals with the derivative of each payload.
     * @return the derivative of this Concatenation
     */
    @Override
    public Concatenation getDerivative()
    {
        SortedSet<Interval<MathFunction>> set = new TreeSet<>();
        for (var interval : this.functions)
        {
            set.add(new Interval<MathFunction>(interval.low(), interval.lowInclusive(), interval.high(),
                    interval.highInclusive(), interval.payload().getDerivative()));
        }
        return new Concatenation(set);
    }

    /**
     * Scale every payload by a factor; the intervals are unchanged.
     * @param factor the scale factor
     * @return <code>this</code> when <code>factor</code> is 1.0; otherwise a new Concatenation with scaled payloads
     */
    @Override
    public MathFunction scaleBy(final double factor)
    {
        if (factor == 1.0)
        {
            return this;
        }
        SortedSet<Interval<MathFunction>> result = new TreeSet<>();
        for (Interval<MathFunction> interval : this.functions)
        {
            result.add(new Interval<MathFunction>(interval.low(), interval.lowInclusive(), interval.high(),
                    interval.highInclusive(), interval.payload().scaleBy(factor)));
        }
        return new Concatenation(result);
    }

    @Override
    public int sortPriority()
    {
        return 110;
    }

    /**
     * Construct a concatenation that is piecewise linear through a given set of points.
     * @param map mapping from domain to value at the inflection points
     * @return new Concatenation that is piecewise linear and connects the given points
     * @throws IllegalArgumentException when <code>map</code> contains fewer than 2 entries
     */
    public static Concatenation continuousPiecewiseLinear(final SortedMap<Double, Double> map)
    {
        SortedSet<Interval<MathFunction>> intervals = new TreeSet<>();
        Entry<Double, Double> prevEntry = null;
        for (Entry<Double, Double> nextEntry : map.entrySet())
        {
            if (prevEntry != null)
            {
                // create one linear section: constant + slope * x, passing through both entries
                double slope = (nextEntry.getValue() - prevEntry.getValue()) / (nextEntry.getKey() - prevEntry.getKey());
                Power powerFunction = new Power(slope, 1);
                double constant = prevEntry.getValue() - powerFunction.apply(prevEntry.getKey());
                MathFunction function = new Sum(new Constant(constant), powerFunction);
                // Only the first section includes its low end; every section includes its high end, so sections do not
                // overlap and there are no gaps between them
                intervals.add(new Interval<MathFunction>(prevEntry.getKey(), intervals.isEmpty(), nextEntry.getKey(), true,
                        function));
            }
            prevEntry = nextEntry;
        }
        Throw.when(intervals.isEmpty(), IllegalArgumentException.class, "need at least two points");
        return new Concatenation(intervals);
    }

    /**
     * Construct a concatenation that is piecewise linear through a given set of input-output pairs.
     * @param arguments the input-output pairs; these specify the inflection points
     * @return new Concatenation that is piecewise linear and connects the given points
     * @throws IllegalArgumentException when <code>arguments</code> contains an odd number of entries, or fewer than 2 domain
     *             values, or duplicate domain values with differing function values
     */
    public static Concatenation continuousPiecewiseLinear(final double... arguments)
    {
        Throw.when(arguments.length % 2 != 0, IllegalArgumentException.class, "need an even number of arguments");
        SortedMap<Double, Double> map = new TreeMap<>();
        for (int i = 0; i < arguments.length; i += 2)
        {
            Throw.when(map.containsKey(arguments[i]) && arguments[i + 1] != map.get(arguments[i]),
                    IllegalArgumentException.class, "duplicate domain value with different function value is not permitted");
            map.put(arguments[i], arguments[i + 1]);
        }
        return continuousPiecewiseLinear(map);
    }

    /**
     * Compare with another Concatenation. All Concatenations currently compare as equal within this sub type.
     * @param other the other MathFunction, which must be a Concatenation
     * @return 0
     * @throws IllegalArgumentException when <code>other</code> is not a Concatenation
     */
    @Override
    public int compareWithinSubType(final MathFunction other)
    {
        Throw.when(!(other instanceof Concatenation), IllegalArgumentException.class, "other is of wrong type");
        // NOTE(review): returning 0 means distinct Concatenations are indistinguishable by this ordering; confirm that no
        // caller relies on this ordering to keep distinct Concatenations apart (e.g. in a sorted set)
        return 0;
    }

    /**
     * Report all non-continuities and all points where <code>this</code> function is non differentiable, or non-evaluable. If
     * another <code>MathFunction</code> is chained, neither the transformation of that function, nor any discontinuities of
     * that <code>MathFunction</code> are taken into account, as there is (currently) no way to figure out what values of the
     * domain of the chained function result in values that correspond to the discontinuities of <code>this</code> function.
     * @param interval the interval on which to report the discontinuities
     * @return iterator that will generate all discontinuities in the interval
     */
    public Iterator<Interval<Discontinuity>> discontinuities(final Interval<?> interval)
    {
        return new Iterator<Interval<Discontinuity>>()
        {
            /** The interval over which the discontinuities were requested. */
            private final Interval<?> requestedInterval = interval;

            /** Iterator that visits all the internal intervals/functions of the Concatenation in sequence. */
            private final Iterator<Interval<MathFunction>> internalIterator = Concatenation.this.functions.iterator();

            /** The current interval (made available by the hasNext method and cleared by the next method). */
            private Interval<MathFunction> currentInterval = null;

            @Override
            public boolean hasNext()
            {
                // Skip internal intervals that do not intersect the requested interval
                while (this.currentInterval == null && this.internalIterator.hasNext())
                {
                    this.currentInterval = this.internalIterator.next().intersection(this.requestedInterval);
                }
                return this.currentInterval != null;
            }

            @Override
            public Interval<Discontinuity> next()
            {
                // Advance on demand, so next() also works when the caller did not call hasNext() first (Iterator contract)
                Throw.when(!hasNext(), NoSuchElementException.class, "Out of data");
                Interval<Discontinuity> result = this.currentInterval.payload() instanceof Nan
                        ? new Interval<>(this.currentInterval.low(), true, this.currentInterval.high(),
                                this.currentInterval.highInclusive(), Discontinuity.GAP)
                        : new Interval<>(this.currentInterval.low(), true, this.currentInterval.low(), true,
                                Discontinuity.KNOT);
                this.currentInterval = null;
                return result;
            }
        };
    }

    /**
     * Report on the knots in the given interval. Part of the interval lying outside the combined domain yields
     * KNOWN_INFINITE; interval boundaries within the requested interval yield KNOWN_FINITE; each payload contributes its own
     * report, limited to the part of the requested interval where that payload is active.
     * @param interval the interval of interest
     * @return the combined KnotReport
     */
    @Override
    public KnotReport getKnotReport(final Interval<?> interval)
    {
        KnotReport result = KnotReport.NONE;
        if (this.functions.first().low() > interval.low() || (this.functions.last().high() < interval.high()))
        {
            result = KnotReport.KNOWN_INFINITE;
        }
        for (Interval<MathFunction> i : this.functions)
        {
            Interval<MathFunction> intersection = i.intersection(interval);
            if (intersection != null)
            {
                if (interval.covers(i.low()))
                {
                    result = result.combineWith(KnotReport.KNOWN_FINITE);
                }
                if (interval.covers(i.high()))
                {
                    result = result.combineWith(KnotReport.KNOWN_FINITE);
                }
                // Only the part where this payload is active matters; knots of the payload elsewhere are irrelevant
                result = result.combineWith(i.payload().getKnotReport(intersection));
            }
        }
        return result;
    }

    /**
     * Collect the knots in the given interval: the boundaries of the internal intervals that lie in the requested interval,
     * plus the knots of each payload within the part of the requested interval where that payload is active.
     * @param interval the interval of interest
     * @return sorted set of knots
     * @throws UnsupportedOperationException when (part of) <code>interval</code> lies outside the combined domain
     */
    @Override
    public SortedSet<Double> getKnots(final Interval<?> interval)
    {
        Throw.when(this.functions.first().low() > interval.low() || (this.functions.last().high() < interval.high()),
                UnsupportedOperationException.class, "Concatenation is undefined over (part of) " + interval);
        SortedSet<Double> result = new TreeSet<Double>();
        for (Interval<MathFunction> i : this.functions)
        {
            Interval<MathFunction> intersection = i.intersection(interval);
            if (intersection != null)
            {
                // Only the part where this payload is active matters; knots of the payload elsewhere are irrelevant
                result.addAll(i.payload().getKnots(intersection));
            }
            if (interval.covers(i.low()))
            {
                result.add(i.low());
            }
            if (interval.covers(i.high()))
            {
                result.add(i.high());
            }
        }
        return result;
    }

    @Override
    public String toString()
    {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append("IntervalSet(");
        boolean first = true;
        for (var interval : this.functions)
        {
            if (!first)
            {
                stringBuilder.append(", ");
            }
            stringBuilder.append(interval.toString());
            first = false;
        }
        stringBuilder.append(")");
        return stringBuilder.toString();
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(this.functions);
    }

    @SuppressWarnings("checkstyle:needbraces")
    @Override
    public boolean equals(final Object obj)
    {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        Concatenation other = (Concatenation) obj;
        return Objects.equals(this.functions, other.functions);
    }

    /**
     * The various discontinuities reported by the <code>discontinuities</code> method. Public because it appears in the
     * return type of that public method.
     */
    public enum Discontinuity
    {
        /** Continuous, but not differentiable. */
        KNOT,

        /** Not continuous (and, therefore, not differentiable). */
        DISCONTINUOUS,

        /** Function undefined in this interval; the <code>MathFunction</code> will yield <code>NaN</code> in this interval. */
        GAP;
    }
}