@@ -83,8 +83,16 @@ static ssize_t store_sockfd(struct devic
}
socket = sockfd_lookup(sockfd, &err);
- if (!socket)
+ if (!socket) {
+ dev_err(dev, "failed to lookup sock");
goto err;
+ }
+
+ if (socket->type != SOCK_STREAM) {
+ dev_err(dev, "Expecting SOCK_STREAM - found %d",
+ socket->type);
+ goto sock_err;
+ }
sdev->ud.tcp_socket = socket;
sdev->ud.sockfd = sockfd;
@@ -114,6 +122,8 @@ static ssize_t store_sockfd(struct devic
return count;
+sock_err:
+ sockfd_put(socket);
err:
spin_unlock_irq(&sdev->ud.lock);
return -EINVAL;