001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.analysis.interpolation;
018    
019    import java.io.Serializable;
020    import java.util.Arrays;
021    
022    import org.apache.commons.math.MathException;
023    import org.apache.commons.math.analysis.polynomials.PolynomialSplineFunction;
024    
025    /**
026     * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
027     * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
028     * real univariate functions.
029     * <p/>
030     * For reference, see
031     * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf">
032     * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
033     * Scatterplots</a>
034     * <p/>
035     * This class implements both the loess method and serves as an interpolation
036     * adapter to it, allowing to build a spline on the obtained loess fit.
037     *
038     * @version $Revision: 794709 $ $Date: 2009-07-16 11:09:02 -0400 (Thu, 16 Jul 2009) $
039     * @since 2.0
040     */
041    public class LoessInterpolator
042            implements UnivariateRealInterpolator, Serializable {
043    
044        /** serializable version identifier. */
045        private static final long serialVersionUID = 5204927143605193821L;
046    
047        /**
048         * Default value of the bandwidth parameter.
049         */
050        public static final double DEFAULT_BANDWIDTH = 0.3;
051        /**
052         * Default value of the number of robustness iterations.
053         */
054        public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
055    
056        /**
057         * The bandwidth parameter: when computing the loess fit at
058         * a particular point, this fraction of source points closest
059         * to the current point is taken into account for computing
060         * a least-squares regression.
061         * <p/>
062         * A sensible value is usually 0.25 to 0.5.
063         */
064        private final double bandwidth;
065    
066        /**
067         * The number of robustness iterations parameter: this many
068         * robustness iterations are done.
069         * <p/>
070         * A sensible value is usually 0 (just the initial fit without any
071         * robustness iterations) to 4.
072         */
073        private final int robustnessIters;
074    
075        /**
076         * Constructs a new {@link LoessInterpolator}
077         * with a bandwidth of {@link #DEFAULT_BANDWIDTH} and
078         * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations.
079         * See {@link #LoessInterpolator(double, int)} for an explanation of
080         * the parameters.
081         */
082        public LoessInterpolator() {
083            this.bandwidth = DEFAULT_BANDWIDTH;
084            this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
085        }
086    
087        /**
088         * Constructs a new {@link LoessInterpolator}
089         * with given bandwidth and number of robustness iterations.
090         *
091         * @param bandwidth  when computing the loess fit at
092         * a particular point, this fraction of source points closest
093         * to the current point is taken into account for computing
094         * a least-squares regression.</br>
095         * A sensible value is usually 0.25 to 0.5, the default value is
096         * {@link #DEFAULT_BANDWIDTH}.
097         * @param robustnessIters This many robustness iterations are done.</br>
098         * A sensible value is usually 0 (just the initial fit without any
099         * robustness iterations) to 4, the default value is
100         * {@link #DEFAULT_ROBUSTNESS_ITERS}.
101         * @throws MathException if bandwidth does not lie in the interval [0,1]
102         * or if robustnessIters is negative.
103         */
104        public LoessInterpolator(double bandwidth, int robustnessIters) throws MathException {
105            if (bandwidth < 0 || bandwidth > 1) {
106                throw new MathException("bandwidth must be in the interval [0,1], but got {0}",
107                                        bandwidth);
108            }
109            this.bandwidth = bandwidth;
110            if (robustnessIters < 0) {
111                throw new MathException("the number of robustness iterations must " +
112                                        "be non-negative, but got {0}",
113                                        robustnessIters);
114            }
115            this.robustnessIters = robustnessIters;
116        }
117    
118        /**
119         * Compute an interpolating function by performing a loess fit
120         * on the data at the original abscissae and then building a cubic spline
121         * with a
122         * {@link org.apache.commons.math.analysis.interpolation.SplineInterpolator}
123         * on the resulting fit.
124         *
125         * @param xval the arguments for the interpolation points
126         * @param yval the values for the interpolation points
127         * @return A cubic spline built upon a loess fit to the data at the original abscissae
128         * @throws MathException  if some of the following conditions are false:
129         * <ul>
130         * <li> Arguments and values are of the same size that is greater than zero</li>
131         * <li> The arguments are in a strictly increasing order</li>
132         * <li> All arguments and values are finite real numbers</li>
133         * </ul>
134         */
135        public final PolynomialSplineFunction interpolate(
136                final double[] xval, final double[] yval) throws MathException {
137            return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
138        }
139    
140        /**
141         * Compute a loess fit on the data at the original abscissae.
142         *
143         * @param xval the arguments for the interpolation points
144         * @param yval the values for the interpolation points
145         * @return values of the loess fit at corresponding original abscissae
146         * @throws MathException if some of the following conditions are false:
147         * <ul>
148         * <li> Arguments and values are of the same size that is greater than zero</li>
149         * <li> The arguments are in a strictly increasing order</li>
150         * <li> All arguments and values are finite real numbers</li>
151         * </ul>
152         */
153        public final double[] smooth(final double[] xval, final double[] yval)
154                throws MathException {
155            if (xval.length != yval.length) {
156                throw new MathException(
157                        "Loess expects the abscissa and ordinate arrays " +
158                        "to be of the same size, " +
159                        "but got {0} abscisssae and {1} ordinatae",
160                        xval.length, yval.length);
161            }
162    
163            final int n = xval.length;
164    
165            if (n == 0) {
166                throw new MathException("Loess expects at least 1 point");
167            }
168    
169            checkAllFiniteReal(xval, true);
170            checkAllFiniteReal(yval, false);
171    
172            checkStrictlyIncreasing(xval);
173    
174            if (n == 1) {
175                return new double[]{yval[0]};
176            }
177    
178            if (n == 2) {
179                return new double[]{yval[0], yval[1]};
180            }
181    
182            int bandwidthInPoints = (int) (bandwidth * n);
183    
184            if (bandwidthInPoints < 2) {
185                throw new MathException(
186                        "the bandwidth must be large enough to " +
187                        "accomodate at least 2 points. There are {0} " +
188                        " data points, and bandwidth must be at least {1} " +
189                        " but it is only {2}",
190                        n, 2.0 / n, bandwidth);
191            }
192    
193            final double[] res = new double[n];
194    
195            final double[] residuals = new double[n];
196            final double[] sortedResiduals = new double[n];
197    
198            final double[] robustnessWeights = new double[n];
199    
200            // Do an initial fit and 'robustnessIters' robustness iterations.
201            // This is equivalent to doing 'robustnessIters+1' robustness iterations
202            // starting with all robustness weights set to 1.
203            Arrays.fill(robustnessWeights, 1);
204    
205            for (int iter = 0; iter <= robustnessIters; ++iter) {
206                final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
207                // At each x, compute a local weighted linear regression
208                for (int i = 0; i < n; ++i) {
209                    final double x = xval[i];
210    
211                    // Find out the interval of source points on which
212                    // a regression is to be made.
213                    if (i > 0) {
214                        updateBandwidthInterval(xval, i, bandwidthInterval);
215                    }
216    
217                    final int ileft = bandwidthInterval[0];
218                    final int iright = bandwidthInterval[1];
219    
220                    // Compute the point of the bandwidth interval that is
221                    // farthest from x
222                    final int edge;
223                    if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
224                        edge = ileft;
225                    } else {
226                        edge = iright;
227                    }
228    
229                    // Compute a least-squares linear fit weighted by
230                    // the product of robustness weights and the tricube
231                    // weight function.
232                    // See http://en.wikipedia.org/wiki/Linear_regression
233                    // (section "Univariate linear case")
234                    // and http://en.wikipedia.org/wiki/Weighted_least_squares
235                    // (section "Weighted least squares")
236                    double sumWeights = 0;
237                    double sumX = 0, sumXSquared = 0, sumY = 0, sumXY = 0;
238                    double denom = Math.abs(1.0 / (xval[edge] - x));
239                    for (int k = ileft; k <= iright; ++k) {
240                        final double xk = xval[k];
241                        final double yk = yval[k];
242                        double dist;
243                        if (k < i) {
244                            dist = (x - xk);
245                        } else {
246                            dist = (xk - x);
247                        }
248                        final double w = tricube(dist * denom) * robustnessWeights[k];
249                        final double xkw = xk * w;
250                        sumWeights += w;
251                        sumX += xkw;
252                        sumXSquared += xk * xkw;
253                        sumY += yk * w;
254                        sumXY += yk * xkw;
255                    }
256    
257                    final double meanX = sumX / sumWeights;
258                    final double meanY = sumY / sumWeights;
259                    final double meanXY = sumXY / sumWeights;
260                    final double meanXSquared = sumXSquared / sumWeights;
261    
262                    final double beta;
263                    if (meanXSquared == meanX * meanX) {
264                        beta = 0;
265                    } else {
266                        beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
267                    }
268    
269                    final double alpha = meanY - beta * meanX;
270    
271                    res[i] = beta * x + alpha;
272                    residuals[i] = Math.abs(yval[i] - res[i]);
273                }
274    
275                // No need to recompute the robustness weights at the last
276                // iteration, they won't be needed anymore
277                if (iter == robustnessIters) {
278                    break;
279                }
280    
281                // Recompute the robustness weights.
282    
283                // Find the median residual.
284                // An arraycopy and a sort are completely tractable here, 
285                // because the preceding loop is a lot more expensive
286                System.arraycopy(residuals, 0, sortedResiduals, 0, n);
287                Arrays.sort(sortedResiduals);
288                final double medianResidual = sortedResiduals[n / 2];
289    
290                if (medianResidual == 0) {
291                    break;
292                }
293    
294                for (int i = 0; i < n; ++i) {
295                    final double arg = residuals[i] / (6 * medianResidual);
296                    robustnessWeights[i] = (arg >= 1) ? 0 : Math.pow(1 - arg * arg, 2);
297                }
298            }
299    
300            return res;
301        }
302    
303        /**
304         * Given an index interval into xval that embraces a certain number of
305         * points closest to xval[i-1], update the interval so that it embraces
306         * the same number of points closest to xval[i]
307         *
308         * @param xval arguments array
309         * @param i the index around which the new interval should be computed
310         * @param bandwidthInterval a two-element array {left, right} such that: <p/>
311         * <tt>(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])</tt>
312         * <p/> and also <p/>
313         * <tt>(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])</tt>.
314         * The array will be updated.
315         */
316        private static void updateBandwidthInterval(final double[] xval, final int i,
317                                                    final int[] bandwidthInterval) {
318            final int left = bandwidthInterval[0];
319            final int right = bandwidthInterval[1];
320    
321            // The right edge should be adjusted if the next point to the right
322            // is closer to xval[i] than the leftmost point of the current interval
323            if (right < xval.length - 1 &&
324               xval[right+1] - xval[i] < xval[i] - xval[left]) {
325                bandwidthInterval[0]++;
326                bandwidthInterval[1]++;
327            }
328        }
329    
330        /**
331         * Compute the 
332         * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
333         * weight function
334         *
335         * @param x the argument
336         * @return (1-|x|^3)^3
337         */
338        private static double tricube(final double x) {
339            final double tmp = 1 - x * x * x;
340            return tmp * tmp * tmp;
341        }
342    
343        /**
344         * Check that all elements of an array are finite real numbers.
345         *
346         * @param values the values array
347         * @param isAbscissae if true, elements are abscissae otherwise they are ordinatae
348         * @throws MathException if one of the values is not
349         *         a finite real number
350         */
351        private static void checkAllFiniteReal(final double[] values, final boolean isAbscissae)
352            throws MathException {
353            for (int i = 0; i < values.length; i++) {
354                final double x = values[i];
355                if (Double.isInfinite(x) || Double.isNaN(x)) {
356                    final String pattern = isAbscissae ?
357                            "all abscissae must be finite real numbers, but {0}-th is {1}" :
358                            "all ordinatae must be finite real numbers, but {0}-th is {1}";
359                    throw new MathException(pattern, i, x);
360                }
361            }
362        }
363    
364        /**
365         * Check that elements of the abscissae array are in a strictly
366         * increasing order.
367         *
368         * @param xval the abscissae array
369         * @throws MathException if the abscissae array
370         * is not in a strictly increasing order
371         */
372        private static void checkStrictlyIncreasing(final double[] xval)
373            throws MathException {
374            for (int i = 0; i < xval.length; ++i) {
375                if (i >= 1 && xval[i - 1] >= xval[i]) {
376                    throw new MathException(
377                            "the abscissae array must be sorted in a strictly " +
378                            "increasing order, but the {0}-th element is {1} " +
379                            "whereas {2}-th is {3}",
380                            i - 1, xval[i - 1], i, xval[i]);
381                }
382            }
383        }
384    }