1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math3.analysis.differentiation; 19 20 import org.apache.commons.math3.TestUtils; 21 import org.apache.commons.math3.exception.DimensionMismatchException; 22 import org.apache.commons.math3.exception.MathIllegalArgumentException; 23 import org.apache.commons.math3.util.FastMath; 24 import org.junit.Test; 25 26 27 /** 28 * Test for class {@link GradientFunction}. 
* <p>
* Both tests sample a regular grid covering [-10, 10) on every axis with a
* step of 0.5 and check, point by point, that the gradient produced by
* {@link GradientFunction} matches a hand-written closed-form gradient.
*/
public class GradientFunctionTest {

    /** Grid indices run over [-GRID_HALF_WIDTH, GRID_HALF_WIDTH). */
    private static final int GRID_HALF_WIDTH = 20;

    /** Distance between two neighbouring grid coordinates. */
    private static final double GRID_STEP = 0.5;

    @Test
    public void test2DDistance() {
        final EuclideanDistance distance = new EuclideanDistance();
        final GradientFunction gradientOfDistance = new GradientFunction(distance);
        for (int i = -GRID_HALF_WIDTH; i < GRID_HALF_WIDTH; ++i) {
            for (int j = -GRID_HALF_WIDTH; j < GRID_HALF_WIDTH; ++j) {
                checkGradient(distance, gradientOfDistance,
                              new double[] { coordinate(i), coordinate(j) });
            }
        }
    }

    @Test
    public void test3DDistance() {
        final EuclideanDistance distance = new EuclideanDistance();
        final GradientFunction gradientOfDistance = new GradientFunction(distance);
        for (int i = -GRID_HALF_WIDTH; i < GRID_HALF_WIDTH; ++i) {
            for (int j = -GRID_HALF_WIDTH; j < GRID_HALF_WIDTH; ++j) {
                for (int k = -GRID_HALF_WIDTH; k < GRID_HALF_WIDTH; ++k) {
                    checkGradient(distance, gradientOfDistance,
                                  new double[] { coordinate(i), coordinate(j), coordinate(k) });
                }
            }
        }
    }

    /**
     * Maps an integer grid index to its coordinate.
     * <p>
     * Integer indices are used instead of a floating-point loop counter.
     * Every multiple of 0.5 in [-10, 10] is exactly representable, so the
     * produced coordinates are the same values -10, -9.5, ..., 9.5.
     *
     * @param index grid index
     * @return {@code index * GRID_STEP}
     */
    private static double coordinate(final int index) {
        return GRID_STEP * index;
    }

    /**
     * Asserts that the automatically differentiated gradient matches the
     * reference gradient at one point.
     *
     * @param reference function providing the closed-form gradient
     * @param candidate gradient function under test
     * @param point evaluation point
     */
    private static void checkGradient(final EuclideanDistance reference,
                                      final GradientFunction candidate,
                                      final double[] point) {
        // Reference gradient is evaluated first, then the candidate,
        // before comparison with an absolute tolerance of 1e-15.
        TestUtils.assertEquals(reference.gradient(point), candidate.value(point), 1.0e-15);
    }

    /**
     * Test fixture: the Euclidean norm sqrt(sum of x_i^2).
     * <p>
     * It is implemented both for plain doubles and for
     * {@link DerivativeStructure} arguments, together with the closed-form
     * gradient x_i / ||x||.
     */
    private static class EuclideanDistance implements MultivariateDifferentiableFunction {

        /**
         * Computes the Euclidean norm of a point.
         *
         * @param point coordinates
         * @return square root of the sum of squared coordinates
         */
        public double value(double[] point) {
            double sumOfSquares = 0;
            for (int i = 0; i < point.length; ++i) {
                sumOfSquares += point[i] * point[i];
            }
            return FastMath.sqrt(sumOfSquares);
        }

        /**
         * Computes the Euclidean norm with derivatives propagated through
         * {@link DerivativeStructure} arithmetic.
         * <p>
         * The accumulator starts from the field zero obtained from
         * {@code point[0]}, so the array must not be empty.
         *
         * @param point coordinates carrying derivative information
         * @return norm of the point, with derivatives
         * @throws DimensionMismatchException declared by the interface contract
         * @throws MathIllegalArgumentException declared by the interface contract
         */
        public DerivativeStructure value(DerivativeStructure[] point)
            throws DimensionMismatchException, MathIllegalArgumentException {
            DerivativeStructure sumOfSquares = point[0].getField().getZero();
            for (int i = 0; i < point.length; ++i) {
                sumOfSquares = sumOfSquares.add(point[i].multiply(point[i]));
            }
            return sumOfSquares.sqrt();
        }

        /**
         * Closed-form gradient of the Euclidean norm: x_i / ||x||.
         * <p>
         * At the origin (which the test grids include) the norm is zero,
         * so every component is 0.0 / 0.0, i.e. NaN.
         *
         * @param point coordinates
         * @return gradient components, one per coordinate
         */
        public double[] gradient(double[] point) {
            final double norm = value(point);
            final double[] components = new double[point.length];
            for (int i = 0; i < components.length; ++i) {
                components[i] = point[i] / norm;
            }
            return components;
        }

    }

}