Artificial Neural Networks: Linear Multiclass Classification (Part 3)
Source: Internet  Editor: 程序博客网  Time: 2024/05/16 08:20
In the last section, we went over how to use a linear neural network to perform classification. We covered using both the perceptron algorithm and gradient descent with a sigmoid activation function to learn the placement of the decision boundary in our feature space. However, we only covered binary classification. What if we instead want to classify a point as belonging to one of $K$ classes?
Theory
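Multiclass classification using a linear neural network is a fairly simple extension of the binary classification setup. You may think that instead of outputting 0/1 from our second-layer node, we could output 0, 1, ..., $K-1$ for $K$ classes. However, a single scalar output imposes an ordering on the classes that usually does not exist: it treats class 0 as "closer" to class 1 than to class 2.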
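Consider instead representing a label using a binary vector of length $K$ (one entry per class), with a 1 in the position of the point's class and 0 everywhere else (a "one-hot" encoding). The network then has $K$ output nodes, each with its own weight vector, and the normalized outputs are read as the probability of each class.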
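The following is a minimal sketch of this label representation and of how $K$ output nodes become class probabilities. It assumes integer labels 0, ..., $K-1$ (as in the implementation section) and one weight row per class. The helper names `one_hot` and `class_probabilities` are hypothetical and do not come from the original code.

```python
import numpy as np

def one_hot(labels, K):
    """Turn integer labels in {0, ..., K-1} into length-K binary vectors."""
    labels = np.asarray(labels, dtype=int).ravel()  # also accepts an (N, 1) float column
    encoded = np.zeros((labels.size, K))
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

def class_probabilities(w, x):
    """One probability per class: exponentiate each output node's score w_k . x, then normalize."""
    exp_vals = np.exp(w.dot(x))  # w has one row (weight vector) per class
    return exp_vals / np.sum(exp_vals)

print(one_hot([0, 2, 1], 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
print(class_probabilities(np.zeros((3, 3)), np.array([1.0, 2.0, 3.0])))
# all-zero weights give every class probability 1/3
```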
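Our training routine is exactly the same as in Part 2. The difference is that the objective is now the multinomial logistic regression log-likelihood $\mathcal{L}$, which we maximize (equivalently, we run gradient descent on $-\mathcal{L}$), and its gradient is slightly different. Writing $w_k$ for the weight vector of class $k$ and

$$P(y_i = k \mid x_i) = \frac{\exp(w_k^\top x_i)}{\sum_{j=1}^{K}\exp(w_j^\top x_i)},$$

the gradient with respect to $w_k$ is

$$\nabla_{w_k}\mathcal{L} = \sum_{i=1}^{N} x_i\left(\mathbb{1}[y_i = k] - P(y_i = k \mid x_i)\right).$$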
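One way to sanity-check this gradient formula is to compare it with central finite differences of the log-likelihood on a small random problem. The sketch below is an illustration under several assumptions. `log_likelihood` and `gradient` are hypothetical helper names, and the labels `y` are integers. The problem sizes and seed are arbitrary.

```python
import numpy as np

def log_likelihood(w, X, y):
    """Sum of log P(y_i | x_i) under the softmax model; y holds integer labels."""
    scores = X.dot(w.T)                                   # (N, K): one score per class per point
    scores = scores - scores.max(axis=1, keepdims=True)   # shift for stability; log-probs unchanged
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    return log_probs[np.arange(len(y)), y].sum()

def gradient(w, X, y):
    """Row k is sum_i x_i * (1[y_i = k] - P(y_i = k | x_i))."""
    scores = X.dot(w.T)
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    targets = np.zeros_like(probs)
    targets[np.arange(len(y)), y] = 1.0                   # one-hot labels
    return (targets - probs).T.dot(X)                     # (K, D), same shape as w

# Small random problem (arbitrary sizes, chosen only for the check)
rng = np.random.default_rng(0)
N, D, K = 20, 3, 3
X = rng.normal(size=(N, D))
y = rng.integers(0, K, size=N)
w = rng.normal(size=(K, D))

# Central finite differences, one weight at a time
eps = 1e-6
numeric = np.zeros_like(w)
for k in range(K):
    for d in range(D):
        step = np.zeros_like(w)
        step[k, d] = eps
        numeric[k, d] = (log_likelihood(w + step, X, y) - log_likelihood(w - step, X, y)) / (2 * eps)

print(np.max(np.abs(numeric - gradient(w, X, y))))  # should be close to zero
```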
Implementation
The implementation of multiclass linear classification doesn't change much from the binary case, except for the gradient and how we label our data points. For this toy example, we'll be generating 3 clusters of two-dimensional data. To make things more interesting, we won't restrict them to be linearly separable.
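```python
import numpy as np

# Generate three random clusters of 2D data
N_c = 200  # points per cluster
A = 0.6 * np.random.randn(N_c, 2) + [1, 1]
B = 0.6 * np.random.randn(N_c, 2) + [3, 3]
C = 0.6 * np.random.randn(N_c, 2) + [3, 0]
# Prepend a column of ones so the first weight of each class acts as a bias
X = np.hstack((np.ones(3 * N_c).reshape(3 * N_c, 1), np.vstack((A, B, C))))
# Integer class labels as an (N, 1) column: 0 for A, 1 for B, 2 for C
Y = np.vstack((np.zeros(N_c).reshape(N_c, 1),
               np.ones(N_c).reshape(N_c, 1),
               2 * np.ones(N_c).reshape(N_c, 1)))
K = 3        # number of classes
N = K * N_c  # total number of points
```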
Next we run gradient descent using the multinomial logistic regression gradient:
```python
# Run gradient descent (ascent on the log-likelihood)
eta = 1E-2
max_iter = 1000
w = np.zeros((K, 3))  # one row per class: bias weight + 2 feature weights
grad_thresh = 5
for t in range(0, max_iter):
    grad_t = np.zeros((K, 3))
    for i in range(0, N):
        x_i = X[i, :]
        y_i = int(Y[i, 0])
        # Unnormalized likelihoods: exponentiated score of each class
        exp_vals = np.exp(w.dot(x_i))
        probs = exp_vals / np.sum(exp_vals)
        # Gradient term x_i * (1[y_i = k] - P(y_i = k | x_i)) for every class k
        grad_t -= np.outer(probs, x_i)
        grad_t[y_i, :] += x_i
    w = w + 1 / float(N) * eta * grad_t
    grad_norm = np.linalg.norm(grad_t)
    if grad_norm < grad_thresh:
        print("Converged in", t + 1, "steps.")
        break
    if t == max_iter - 1:
        print("Warning, did not converge.")
```
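There are a couple of things to note here. First, what was a weight vector in the binary case is now a $K \times 3$ matrix `w` (here $3 \times 3$). Row $k$ holds the bias weight and the two feature weights for class $k$.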
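Second, in the line that computes `exp_vals`, we calculate the unnormalized likelihoods using numpy's `dot` function. `w.dot(x_i)` computes the dot product between the input `x_i` and each row of `w`, producing one score per class. Exponentiating these scores and dividing by their sum gives the probability the model assigns to each class.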
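The original does not cover one practical caveat. In double precision, `np.exp` overflows to `inf` once a score exceeds roughly 709, and the normalized likelihood then becomes `nan`. A common remedy is to subtract the largest score before exponentiating. This is a standard trick and not part of the article's code. Because the same factor cancels in the numerator and denominator, the probabilities do not change. The function names below are hypothetical.

```python
import numpy as np

def likelihood_naive(scores, y):
    """Probability of class y computed the way the training loop computes it."""
    exp_vals = np.exp(scores)
    return exp_vals[y] / np.sum(exp_vals)

def likelihood_stable(scores, y):
    """Same quantity, after shifting all scores so the largest is 0."""
    exp_vals = np.exp(scores - np.max(scores))  # every entry is now <= 1
    return exp_vals[y] / np.sum(exp_vals)

small = np.array([1.0, 2.0, 3.0])
large = small + 1000.0  # same gaps between scores, much larger magnitude

print(likelihood_naive(small, 2), likelihood_stable(small, 2))  # identical values
with np.errstate(over="ignore", invalid="ignore"):
    naive_large = likelihood_naive(large, 2)  # exp overflows to inf, and inf / inf is nan
print(naive_large)
print(likelihood_stable(large, 2))  # unaffected by the shift, so it matches the small case
```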
Running around 1,000 epochs with the given descent parameters will generate classification regions in the following manner:
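Below is a vectorized sketch of the same training procedure, followed by a computation of the classification regions. It processes all $N$ points in one matrix product per step instead of looping over them. Several choices here are assumptions rather than the article's settings: the data is regenerated with a fixed seed, the step size is an illustrative 0.1, and the loop runs a fixed 2000 iterations with no gradient-norm test. Each grid point is assigned to its highest-scoring class.

```python
import numpy as np

# Regenerate three clusters like the article's (assumed seed, for repeatability)
rng = np.random.default_rng(0)
N_c, K = 200, 3
centers = np.array([[1.0, 1.0], [3.0, 3.0], [3.0, 0.0]])
points = np.vstack([0.6 * rng.standard_normal((N_c, 2)) + c for c in centers])
X = np.hstack((np.ones((K * N_c, 1)), points))  # bias column + 2 features
y = np.repeat(np.arange(K), N_c)                # labels 0, 1, 2
N = X.shape[0]
targets = np.zeros((N, K))
targets[np.arange(N), y] = 1.0                  # one-hot labels

def probabilities(w, X):
    """Row i holds P(y = k | x_i) for every class k."""
    scores = X.dot(w.T)
    exp_vals = np.exp(scores - scores.max(axis=1, keepdims=True))
    return exp_vals / exp_vals.sum(axis=1, keepdims=True)

def log_likelihood(w):
    return np.sum(np.log(probabilities(w, X)[np.arange(N), y]))

eta = 0.1  # illustrative step size (the article uses 1E-2)
w = np.zeros((K, 3))
ll_start = log_likelihood(w)
for t in range(2000):
    grad = (targets - probabilities(w, X)).T.dot(X)  # all N points in one product
    w = w + eta * grad / N                           # ascent step on the average gradient
ll_end = log_likelihood(w)
accuracy = np.mean(np.argmax(X.dot(w.T), axis=1) == y)

# Classification regions: label every grid point by its highest-scoring class
xx, yy = np.meshgrid(np.linspace(-1, 5, 61), np.linspace(-2, 5, 71))
grid = np.column_stack((np.ones(xx.size), xx.ravel(), yy.ravel()))
regions = np.argmax(grid.dot(w.T), axis=1).reshape(xx.shape)

print(ll_start, ll_end, accuracy)
```

The grid labels in `regions` can be drawn with matplotlib, for example `plt.contourf(xx, yy, regions)`. The model is linear in $x$, so each region is an intersection of half-planes: a convex, possibly unbounded, polygon.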
Appendix
Deriving the gradient for multinomial logistic regression
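The derivation below is a sketch that assumes the softmax probabilities and the log-likelihood objective $\mathcal{L}$. The code maximizes $\mathcal{L}$ with its `w = w + ...` update. Write $\mathcal{L}$ as a sum of a linear term and a log-sum-exp term, then differentiate each with respect to $w_k$:

```latex
\begin{aligned}
P(y_i = k \mid x_i) &= \frac{\exp(w_k^\top x_i)}{\sum_{j=1}^{K}\exp(w_j^\top x_i)} \\
\mathcal{L}(w) &= \sum_{i=1}^{N} \log P(y_i \mid x_i)
  = \sum_{i=1}^{N}\left( w_{y_i}^\top x_i - \log \sum_{j=1}^{K}\exp(w_j^\top x_i) \right) \\
\frac{\partial}{\partial w_k}\, w_{y_i}^\top x_i &= \mathbb{1}[y_i = k]\, x_i \\
\frac{\partial}{\partial w_k} \log \sum_{j=1}^{K}\exp(w_j^\top x_i)
  &= \frac{\exp(w_k^\top x_i)}{\sum_{j=1}^{K}\exp(w_j^\top x_i)}\, x_i = P(y_i = k \mid x_i)\, x_i \\
\nabla_{w_k}\mathcal{L} &= \sum_{i=1}^{N} x_i \left( \mathbb{1}[y_i = k] - P(y_i = k \mid x_i) \right)
\end{aligned}
```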