Java 中的梯度下降
今天的大多数人工智能都是使用某种形式的神经网络实现的。 在我的前两篇文章中,我介绍了神经网络并向您展示了如何在 Java 中构建神经网络。 神经网络的力量主要来自于它的深度学习能力,而这种能力建立在梯度下降反向传播的概念和执行之上。 我将通过快速深入了解 Java 中的反向传播和梯度下降来结束这个简短的系列文章。
机器学习中的反向传播
有人说人工智能并不是那么聪明,它主要只是反向传播。 那么,现代机器学习的基石是什么?
要了解反向传播,您必须首先了解神经网络的工作原理。 基本上,神经网络是称为神经元的节点的有向图。 神经元有一个特定的结构,它接受输入,将它们与权重相乘,添加一个偏差值,并通过激活函数运行所有这些。 神经元将它们的输出馈送到其他神经元,直到到达输出神经元。 输出神经元产生网络的输出。 (有关更完整的介绍,请参阅机器学习风格:神经网络简介。)
从这里开始,我假设您了解网络及其神经元的结构,包括前馈。 示例和讨论将集中在梯度下降的反向传播上。 我们的神经网络将有一个输出节点、两个“隐藏”节点和两个输入节点。 使用一个相对简单的例子可以更容易地理解算法所涉及的数学。 图 1 显示了示例神经网络的图表。
国际数据集团
图 1. 我们将用于示例的神经网络图。
梯度下降反向传播的思想是将整个网络视为一个多元函数,为损失函数提供输入。 损失函数通过将网络输出与已知的良好结果进行比较来计算一个表示网络执行情况的数字。 与良好结果配对的输入数据集称为训练集。 损失函数旨在随着网络行为远离正确而增加数值。
梯度下降算法采用损失函数并使用偏导数来确定网络中每个变量(权重和偏差)对损失值的贡献。 然后它向后移动,访问每个变量并调整它以减少损失值。
梯度下降的演算
理解梯度下降涉及微积分中的一些概念。 首先是导数的概念。 MathsIsFun.com 对导数有很好的介绍。 简而言之,导数为您提供函数在单个点处的斜率(或变化率)。 换句话说,函数的导数为我们提供了给定输入的变化率。 (微积分的美妙之处在于它可以让我们在没有其他参考点的情况下找到变化——或者更确切地说,它可以让我们假设输入的变化是无限小的。)
下一个重要的概念是偏导数。 偏导数让我们采用多维(也称为多变量)函数并仅隔离其中一个变量以找到给定维度的斜率。
导数回答了以下问题:函数在特定点的变化率(或斜率)是多少? 偏导数回答了以下问题:给定方程的多个输入变量,仅这一个变量的变化率是多少?
梯度下降使用这些思想来访问方程中的每个变量并对其进行调整以最小化方程的输出。 这正是我们在训练我们的网络时想要的。 如果我们将损失函数视为绘制在图表上,我们希望以增量方式向函数的最小值移动。 也就是说,我们要找到全局最小值。
请注意,增量的大小在机器学习中称为“学习率”。
代码中的梯度下降
当我们探索梯度下降反向传播的数学时,我们将紧贴代码。 当数学变得过于抽象时,查看代码将有助于让我们脚踏实地。 让我们先看看我们的 Neuron
类,如清单 1 所示。
清单 1. 一个 Neuron 类
class Neuron {
Random random = new Random();
private Double bias = random.nextGaussian();
private Double weight1 = random.nextGaussian();
private Double weight2 = random.nextGaussian();
public double compute(double input1, double input2){
return Util.sigmoid(this.getSum(input1, input2));
}
public Double getWeight1() { return this.weight1; }
public Double getWeight2() { return this.weight2; }
public Double getSum(double input1, double input2){ return (this.weight1 * input1) + (this.weight2 * input2) + this.bias; }
public Double getDerivedOutput(double input1, double input2){ return Util.sigmoidDeriv(this.getSum(input1, input2)); }
public void adjust(Double w1, Double w2, Double b){
this.weight1 -= w1; this.weight2 -= w2; this.bias -= b;
}
}
这 Neuron
班级只有三个 Double
成员: weight1
, weight2
, 和 bias
. 它也有一些方法。 用于前馈的方法是 compute()
. 它接受两个输入并执行神经元的工作:将每个输入乘以适当的权重,加上偏差,然后通过 sigmoid 函数运行它。
在我们继续之前,让我们重新审视一下 sigmoid 激活的概念,我在神经网络简介中也讨论过它。 清单 2 显示了一个基于 Java 的 sigmoid 激活函数。
清单 2. Util.sigmoid()
public static double sigmoid(double in){
return 1 / (1 + Math.exp(-in));
}
sigmoid 函数接受输入并将欧拉数 (Math.exp) 提高到负数,加 1 再除以 1。效果是将输出压缩在 0 和 1 之间,较大和较小的数字渐近地接近极限。
什么是 sigmoid 函数?
DeepAI.org 对机器学习中的 sigmoid 函数有很好的介绍。
回到 Neuron
清单 1 中的类,除了 compute()
我们有的方法 getSum()
和 getDerivedOutput()
. getSum()
只是进行权重 * 输入 + 偏差计算。 请注意 compute()
需要 getSum()
并运行它 sigmoid()
. 这 getDerivedOutput()
方法运行 getSum()
通过不同的函数:sigmoid 函数的导数。
导数在行动
现在看一下清单 3,它显示了 Java 中的 sigmoid 导数函数。 我们已经从概念上讨论了衍生品,下面是实际应用。
清单 3. Sigmoid 导数
public static double sigmoidDeriv(double in){
double sigmoid = Util.sigmoid(in);
return sigmoid * (1 - sigmoid);
}
记住导数告诉我们函数在其图中的单个点的变化是什么,我们可以感受一下这个导数在说什么:告诉我给定输入的 sigmoid 函数的变化率。 您可以说它告诉我们清单 1 中的预激活神经元对最终激活结果有何影响。
衍生规则
您可能想知道我们如何知道清单 3 中的 sigmoid 导数函数是正确的。 答案是,如果它已经被其他人验证过,并且如果我们知道根据特定规则正确微分的函数是准确的,我们就会知道导数函数是正确的。 一旦我们理解了它们所说的内容并相信它们是准确的,我们就不必回到第一原理并重新发现这些规则——就像我们接受并应用简化代数方程的规则一样。
所以,在实践中,我们是按照求导法则来求导数的。 如果您查看 sigmoid 函数及其导数,您会发现后者可以通过遵循这些规则得出。 出于梯度下降的目的,我们需要了解导数规则,相信它们有效,并了解它们的应用方式。 我们将使用它们来找出每个权重和偏差在网络最终损失结果中所扮演的角色。
符号
符号 f prime f'(x) 是表示“f 对 x 的导数”的一种方式。 另一个是:
国际数据集团
两者是等价的:
国际数据集团
您很快就会看到的另一种表示法是偏导数表示法:
国际数据集团
这就是说,给我变量 x 的 f 的导数。
连锁法则
最令人好奇的衍生规则是链式法则。 它说当一个函数是复合的(函数中的函数,又名高阶函数)时,你可以像这样扩展它:
国际数据集团
我们将使用链式法则来解压缩我们的网络并获得每个权重和偏差的偏导数。