
Check analytic gradients of a function using finite difference approximations
Source: R/check.derivatives.R
check.derivatives.Rd
This function compares the analytic gradients of a function with a finite difference approximation and prints the results of these checks.
Usage
check.derivatives(
.x,
func,
func_grad,
check_derivatives_tol = 1e-04,
check_derivatives_print = "all",
func_grad_name = "grad_f",
...
)
Arguments
- .x
point at which the comparison is done.
- func
function to be evaluated.
- func_grad
function calculating the analytic gradients.
- check_derivatives_tol
tolerance against which the relative error between the analytic gradient and its finite difference approximation is compared; differences beyond this tolerance are flagged as an error.
- check_derivatives_print
option related to the amount of output. 'all' means that all comparisons are shown, 'errors' only shows comparisons that are flagged as an error, and 'none' shows the number of errors only.
- func_grad_name
option to change the name of the gradient function that shows up in the output.
- ...
further arguments passed to the functions func and func_grad.
Value
The return value is a list containing the analytic gradient (analytic), its finite difference approximation (finite_difference), the relative errors (relative_error), and a logical vector comparing the relative errors to the tolerance (flag_derivative_warning).
Examples
library('nloptr')
# example with correct gradient
f <- function(x, a) sum((x - a) ^ 2)
f_grad <- function(x, a) 2 * (x - a)
check.derivatives(.x = 1:10, func = f, func_grad = f_grad,
check_derivatives_print = 'none', a = runif(10))
#> Derivative checker results: 0 error(s) detected.
#> $analytic
#> [1] 1.983350 3.214606 4.372239 7.247503 9.238376 11.470163 13.121331
#> [8] 15.084786 16.918585 18.668640
#>
#> $finite_difference
#> [1] 1.983353 3.214607 4.372239 7.247503 9.238376 11.470163 13.121332
#> [8] 15.084786 16.918585 18.668641
#>
#> $relative_error
#> [1] 1.166881e-06 5.080167e-07 1.256393e-07 2.640727e-08 2.402307e-08
#> [6] 8.769085e-09 4.031029e-08 1.595958e-08 7.682158e-09 9.533407e-09
#>
#> $flag_derivative_warning
#> [1] FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
#>
# example with incorrect gradient
f_grad <- function(x, a) 2 * (x - a) + c(0, 0.1, rep(0, 8))
check.derivatives(.x = 1:10, func = f, func_grad = f_grad,
check_derivatives_print = 'errors', a = runif(10))
#> Derivative checker results: 1 error(s) detected.
#>
#> * grad_f[ 2] = 3.663266e+00 ~ 3.563267e+00 [2.806383e-02]
#>
#> $analytic
#> [1] 1.774602 3.663266 4.424327 7.804294 8.580339 11.564354 13.464113
#> [8] 14.990464 17.622826 19.121141
#>
#> $finite_difference
#> [1] 1.774601 3.563267 4.424328 7.804295 8.580339 11.564354 13.464113
#> [8] 14.990464 17.622826 19.121141
#>
#> $relative_error
#> [1] 6.436755e-07 2.806383e-02 1.232480e-07 8.526459e-08 2.018871e-09
#> [6] 2.536818e-09 1.106238e-08 7.610639e-09 1.837626e-08 4.383579e-09
#>
#> $flag_derivative_warning
#> [1] FALSE TRUE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
#>
# example with incorrect gradient of vector-valued function
g <- function(x, a) c(sum(x - a), sum((x - a) ^ 2))
g_grad <- function(x, a) {
rbind(rep(1, length(x)) + c(0, 0.01, rep(0, 8)),
2 * (x - a) + c(0, 0.1, rep(0, 8)))
}
check.derivatives(.x = 1:10, func = g, func_grad = g_grad,
check_derivatives_print = 'all', a = runif(10))
#> Derivative checker results: 2 error(s) detected.
#>
#> grad_f[1, 1] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 1] = 6.603614e-01 ~ 6.603584e-01 [4.511642e-06]
#> * grad_f[1, 2] = 1.010000e+00 ~ 1.000000e+00 [1.000000e-02]
#> * grad_f[2, 2] = 3.618234e+00 ~ 3.518232e+00 [2.842370e-02]
#> grad_f[1, 3] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 3] = 4.213470e+00 ~ 4.213469e+00 [2.423276e-07]
#> grad_f[1, 4] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 4] = 6.234487e+00 ~ 6.234487e+00 [1.036714e-07]
#> grad_f[1, 5] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 5] = 8.371873e+00 ~ 8.371873e+00 [2.013521e-09]
#> grad_f[1, 6] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 6] = 1.073347e+01 ~ 1.073347e+01 [4.682584e-08]
#> grad_f[1, 7] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 7] = 1.211782e+01 ~ 1.211782e+01 [2.939726e-08]
#> grad_f[1, 8] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 8] = 1.461218e+01 ~ 1.461218e+01 [1.752742e-08]
#> grad_f[1, 9] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 9] = 1.631260e+01 ~ 1.631260e+01 [9.477311e-09]
#> grad_f[1, 10] = 1.000000e+00 ~ 1.000000e+00 [0.000000e+00]
#> grad_f[2, 10] = 1.923071e+01 ~ 1.923071e+01 [5.012393e-09]
#>
#> $analytic
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
#> [1,] 1.0000000 1.010000 1.00000 1.000000 1.000000 1.00000 1.00000 1.00000
#> [2,] 0.6603614 3.618234 4.21347 6.234487 8.371873 10.73347 12.11782 14.61218
#> [,9] [,10]
#> [1,] 1.0000 1.00000
#> [2,] 16.3126 19.23071
#>
#> $finite_difference
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8]
#> [1,] 1.0000000 1.000000 1.000000 1.000000 1.000000 1.00000 1.00000 1.00000
#> [2,] 0.6603584 3.518232 4.213469 6.234487 8.371873 10.73347 12.11782 14.61218
#> [,9] [,10]
#> [1,] 1.0000 1.00000
#> [2,] 16.3126 19.23071
#>
#> $relative_error
#> [,1] [,2] [,3] [,4] [,5] [,6]
#> [1,] 0.000000e+00 0.0100000 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
#> [2,] 4.511642e-06 0.0284237 2.423276e-07 1.036714e-07 2.013521e-09 4.682584e-08
#> [,7] [,8] [,9] [,10]
#> [1,] 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
#> [2,] 2.939726e-08 1.752742e-08 9.477311e-09 5.012393e-09
#>
#> $flag_derivative_warning
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
#> [1,] FALSE TRUE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
#> [2,] FALSE TRUE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
#>