1 package org.djutils.math.functions;
2
3 import java.util.ArrayList;
4 import java.util.Arrays;
5 import java.util.Collections;
6 import java.util.List;
7 import java.util.Objects;
8 import java.util.SortedSet;
9 import java.util.TreeSet;
10
11 import org.djutils.exceptions.Throw;
12
13
14
15
16
17
18
19
20
21
22
23
24
25 public class Sum implements MathFunction
26 {
27
28 private final List<MathFunction> terms;
29
30
31
32
33
34
35
36 public Sum(final MathFunction... functions)
37 {
38 this(Arrays.asList(functions));
39 }
40
41
42
43
44
45
46
47 public Sum(final List<MathFunction> functions)
48 {
49 Throw.when(functions.size() == 0, IllegalArgumentException.class, "Sum needs at least one object to sum");
50 this.terms = simplify(functions);
51 }
52
53
54
55
56
57
58 private List<MathFunction> simplify(final List<MathFunction> functions)
59 {
60 List<MathFunction> result = new ArrayList<>(functions);
61
62
63 for (int index = 0; index < result.size(); index++)
64 {
65 MathFunction function = result.get(index);
66 Throw.whenNull(function, "function");
67 if (function instanceof Sum)
68 {
69
70 result.remove(index);
71 index--;
72 result.addAll(((Sum) function).terms);
73 }
74 }
75
76 for (int index = 0; index < result.size(); index++)
77 {
78 MathFunction function = result.get(index);
79 MathFunction optimized = function.simplify();
80 if (!function.equals(optimized))
81 {
82 result.remove(index);
83 result.add(index, optimized);
84 }
85 }
86 Collections.sort(result);
87
88 for (int index = 0; index < result.size(); index++)
89 {
90 MathFunction function = result.get(index);
91 if (function.equals(Constant.ZERO))
92 {
93 result.remove(index);
94 index--;
95 }
96 else if (index < result.size() - 1)
97 {
98 MathFunction nextFunction = result.get(index + 1);
99 MathFunction merged = function.mergeAdd(nextFunction);
100 if (merged != null)
101 {
102 result.remove(index);
103 result.remove(index);
104 result.add(index, merged);
105 index--;
106 }
107 }
108 }
109 if (result.size() == 0)
110 {
111 result.add(Constant.ZERO);
112 }
113 return result;
114 }
115
116 @Override
117 public double get(final double x)
118 {
119 double result = 0.0;
120 for (MathFunction fi : this.terms)
121 {
122 result += fi.get(x);
123 }
124 return result;
125 }
126
127 @Override
128 public MathFunction getDerivative()
129 {
130 List<MathFunction> derivatives = new ArrayList<>();
131 for (MathFunction term : this.terms)
132 {
133 derivatives.add(term.getDerivative());
134 }
135 return new Sum(derivatives).simplify();
136 }
137
138 @Override
139 public MathFunction simplify()
140 {
141 List<MathFunction> simplifiedTerms = simplify(this.terms);
142 if (simplifiedTerms.size() == 1)
143 {
144 return simplifiedTerms.get(0);
145 }
146 return this;
147 }
148
149 @Override
150 public MathFunction scaleBy(final double factor)
151 {
152 if (factor == 0.0)
153 {
154 return Constant.ZERO;
155 }
156 if (factor == 1.0)
157 {
158 return this;
159 }
160 List<MathFunction> result = new ArrayList<>(this.terms.size());
161 for (MathFunction function : this.terms)
162 {
163 result.add(function.scaleBy(factor));
164 }
165 return new Sum(result);
166 }
167
168 @Override
169 public int sortPriority()
170 {
171 return 101;
172 }
173
174 @Override
175 public int compareWithinSubType(final MathFunction other)
176 {
177 Throw.when(!(other instanceof Sum), IllegalArgumentException.class, "other is of wrong type");
178 Sum otherSum = (Sum) other;
179 for (int index = 0; index < this.terms.size(); index++)
180 {
181 if (index >= otherSum.terms.size())
182 {
183 return 1;
184 }
185 int result = this.terms.get(index).compareTo(otherSum.terms.get(index));
186 if (result != 0)
187 {
188 return result;
189 }
190 }
191 if (otherSum.terms.size() > this.terms.size())
192 {
193 return -1;
194 }
195 return 0;
196 }
197
198 @Override
199 public KnotReport getKnotReport(final Interval<?> interval)
200 {
201 KnotReport result = KnotReport.NONE;
202 for (MathFunction term : this.terms)
203 {
204 result = result.combineWith(term.getKnotReport(interval));
205 }
206 return result;
207 }
208
209 @Override
210 public SortedSet<Double> getKnots(final Interval<?> interval)
211 {
212 SortedSet<Double> result = new TreeSet<>();
213 for (MathFunction term : this.terms)
214 {
215 result.addAll(term.getKnots(interval));
216 }
217 return result;
218 }
219
220 @Override
221 public String toString()
222 {
223 StringBuilder result = new StringBuilder();
224 result.append("\u03A3(");
225 for (int i = 0; i < this.terms.size(); i++)
226 {
227 if (i > 0)
228 {
229 result.append(", ");
230 }
231 result.append(this.terms.get(i).toString());
232 }
233 result.append(")");
234 return result.toString();
235 }
236
237 @Override
238 public int hashCode()
239 {
240 return Objects.hash(this.terms);
241 }
242
243 @SuppressWarnings("checkstyle:needbraces")
244 @Override
245 public boolean equals(final Object obj)
246 {
247 if (this == obj)
248 return true;
249 if (obj == null)
250 return false;
251 if (getClass() != obj.getClass())
252 return false;
253 Sum other = (Sum) obj;
254 return Objects.equals(this.terms, other.terms);
255 }
256
257 }