@@ -808,10 +808,12 @@ static int vsock_shutdown(struct socket
*/
sk = sock->sk;
+
+ lock_sock(sk);
if (sock->state == SS_UNCONNECTED) {
err = -ENOTCONN;
if (sk->sk_type == SOCK_STREAM)
- return err;
+ goto out;
} else {
sock->state = SS_DISCONNECTING;
err = 0;
@@ -820,10 +822,8 @@ static int vsock_shutdown(struct socket
/* Receive and send shutdowns are treated alike. */
mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN);
if (mode) {
- lock_sock(sk);
sk->sk_shutdown |= mode;
sk->sk_state_change(sk);
- release_sock(sk);
if (sk->sk_type == SOCK_STREAM) {
sock_reset_flag(sk, SOCK_DONE);
@@ -831,6 +831,8 @@ static int vsock_shutdown(struct socket
}
}
+out:
+ release_sock(sk);
return err;
}
@@ -464,14 +464,10 @@ static void hvs_shutdown_lock_held(struc
static int hvs_shutdown(struct vsock_sock *vsk, int mode)
{
- struct sock *sk = sk_vsock(vsk);
-
if (!(mode & SEND_SHUTDOWN))
return 0;
- lock_sock(sk);
hvs_shutdown_lock_held(vsk->trans, mode);
- release_sock(sk);
return 0;
}