net: Work around lockdep limitation in sockets that use sockets

Lockdep issues a circular dependency warning when AFS issues an operation
through AF_RXRPC from a context in which the VFS/VM holds the mmap_sem.

The theory lockdep comes up with is as follows:

 (1) If the pagefault handler decides it needs to read pages from AFS, it
     calls AFS with mmap_sem held and AFS begins an AF_RXRPC call, but
     creating a call requires the socket lock:

	mmap_sem must be taken before sk_lock-AF_RXRPC

 (2) afs_open_socket() opens an AF_RXRPC socket and binds it.  rxrpc_bind()
     binds the underlying UDP socket whilst holding its socket lock.
     inet_bind() takes its own socket lock:

	sk_lock-AF_RXRPC must be taken before sk_lock-AF_INET

 (3) Reading from a TCP socket into a userspace buffer might cause a fault
     and thus cause the kernel to take the mmap_sem, but the TCP socket is
     locked whilst doing this:

	sk_lock-AF_INET must be taken before mmap_sem

However, lockdep's theory is wrong in this instance because it deals only
with lock classes and not individual locks.  The AF_INET lock in (2) isn't
really equivalent to the AF_INET lock in (3) as the former deals with a
socket entirely internal to the kernel that never sees userspace.  This is
a limitation in the design of lockdep.

Fix the general case by:

 (1) Double up all the locking keys used in sockets so that one set is
     used if the socket is created by userspace and the other set is used
     if the socket is created by the kernel.

 (2) Store the kern parameter passed to sk_alloc() in a variable in the
     sock struct (sk_kern_sock).  This informs sock_lock_init(),
     sock_init_data() and sk_clone_lock() as to the lock keys to be used.

     Note that the child created by sk_clone_lock() inherits the parent's
     kern setting.

 (3) Add a 'kern' parameter to ->accept() that is analogous to the one
     passed in to ->create() that distinguishes whether kernel_accept() or
     sys_accept4() was the caller and can be passed to sk_alloc().

     Note that a lot of accept functions merely dequeue an already
     allocated socket.  I haven't touched these as the new socket already
     exists before we get the parameter.

     Note also that there are a couple of places where I've made the accepted
     socket unconditionally kernel-based:

	irda_accept()
	rds_tcp_accept_one()
	tcp_accept_from_sock()

     because they follow a sock_create_kern() and accept off of that.

Whilst creating this, I noticed that lustre and ocfs don't create sockets
through sock_create_kern() and thus they aren't marked as for-kernel,
though they appear to be internal.  I wonder if these should do that so
that they use the new set of lock keys.

Signed-off-by: David Howells <dhowells@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
David Howells 2017-03-09 08:09:05 +00:00 committed by David S. Miller
parent 81dca07b3b
commit cdfbabfb2f
38 changed files with 142 additions and 108 deletions

View File

@ -266,7 +266,7 @@ static int alg_setsockopt(struct socket *sock, int level, int optname,
return err; return err;
} }
int af_alg_accept(struct sock *sk, struct socket *newsock) int af_alg_accept(struct sock *sk, struct socket *newsock, bool kern)
{ {
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
const struct af_alg_type *type; const struct af_alg_type *type;
@ -281,7 +281,7 @@ int af_alg_accept(struct sock *sk, struct socket *newsock)
if (!type) if (!type)
goto unlock; goto unlock;
sk2 = sk_alloc(sock_net(sk), PF_ALG, GFP_KERNEL, &alg_proto, 0); sk2 = sk_alloc(sock_net(sk), PF_ALG, GFP_KERNEL, &alg_proto, kern);
err = -ENOMEM; err = -ENOMEM;
if (!sk2) if (!sk2)
goto unlock; goto unlock;
@ -323,9 +323,10 @@ int af_alg_accept(struct sock *sk, struct socket *newsock)
} }
EXPORT_SYMBOL_GPL(af_alg_accept); EXPORT_SYMBOL_GPL(af_alg_accept);
static int alg_accept(struct socket *sock, struct socket *newsock, int flags) static int alg_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
return af_alg_accept(sock->sk, newsock); return af_alg_accept(sock->sk, newsock, kern);
} }
static const struct proto_ops alg_proto_ops = { static const struct proto_ops alg_proto_ops = {

View File

@ -239,7 +239,8 @@ static int hash_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
return err ?: len; return err ?: len;
} }
static int hash_accept(struct socket *sock, struct socket *newsock, int flags) static int hash_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct alg_sock *ask = alg_sk(sk); struct alg_sock *ask = alg_sk(sk);
@ -260,7 +261,7 @@ static int hash_accept(struct socket *sock, struct socket *newsock, int flags)
if (err) if (err)
return err; return err;
err = af_alg_accept(ask->parent, newsock); err = af_alg_accept(ask->parent, newsock, kern);
if (err) if (err)
return err; return err;
@ -378,7 +379,7 @@ static int hash_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
} }
static int hash_accept_nokey(struct socket *sock, struct socket *newsock, static int hash_accept_nokey(struct socket *sock, struct socket *newsock,
int flags) int flags, bool kern)
{ {
int err; int err;
@ -386,7 +387,7 @@ static int hash_accept_nokey(struct socket *sock, struct socket *newsock,
if (err) if (err)
return err; return err;
return hash_accept(sock, newsock, flags); return hash_accept(sock, newsock, flags, kern);
} }
static struct proto_ops algif_hash_ops_nokey = { static struct proto_ops algif_hash_ops_nokey = {

View File

@ -532,7 +532,7 @@ lnet_sock_accept(struct socket **newsockp, struct socket *sock)
newsock->ops = sock->ops; newsock->ops = sock->ops;
rc = sock->ops->accept(sock, newsock, O_NONBLOCK); rc = sock->ops->accept(sock, newsock, O_NONBLOCK, false);
if (rc == -EAGAIN) { if (rc == -EAGAIN) {
/* Nothing ready, so wait for activity */ /* Nothing ready, so wait for activity */
init_waitqueue_entry(&wait, current); init_waitqueue_entry(&wait, current);
@ -540,7 +540,7 @@ lnet_sock_accept(struct socket **newsockp, struct socket *sock)
set_current_state(TASK_INTERRUPTIBLE); set_current_state(TASK_INTERRUPTIBLE);
schedule(); schedule();
remove_wait_queue(sk_sleep(sock->sk), &wait); remove_wait_queue(sk_sleep(sock->sk), &wait);
rc = sock->ops->accept(sock, newsock, O_NONBLOCK); rc = sock->ops->accept(sock, newsock, O_NONBLOCK, false);
} }
if (rc) if (rc)

View File

@ -743,7 +743,7 @@ static int tcp_accept_from_sock(struct connection *con)
newsock->type = con->sock->type; newsock->type = con->sock->type;
newsock->ops = con->sock->ops; newsock->ops = con->sock->ops;
result = con->sock->ops->accept(con->sock, newsock, O_NONBLOCK); result = con->sock->ops->accept(con->sock, newsock, O_NONBLOCK, true);
if (result < 0) if (result < 0)
goto accept_err; goto accept_err;

View File

@ -1863,7 +1863,7 @@ static int o2net_accept_one(struct socket *sock, int *more)
new_sock->type = sock->type; new_sock->type = sock->type;
new_sock->ops = sock->ops; new_sock->ops = sock->ops;
ret = sock->ops->accept(sock, new_sock, O_NONBLOCK); ret = sock->ops->accept(sock, new_sock, O_NONBLOCK, false);
if (ret < 0) if (ret < 0)
goto out; goto out;

View File

@ -73,7 +73,7 @@ int af_alg_unregister_type(const struct af_alg_type *type);
int af_alg_release(struct socket *sock); int af_alg_release(struct socket *sock);
void af_alg_release_parent(struct sock *sk); void af_alg_release_parent(struct sock *sk);
int af_alg_accept(struct sock *sk, struct socket *newsock); int af_alg_accept(struct sock *sk, struct socket *newsock, bool kern);
int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len); int af_alg_make_sg(struct af_alg_sgl *sgl, struct iov_iter *iter, int len);
void af_alg_free_sg(struct af_alg_sgl *sgl); void af_alg_free_sg(struct af_alg_sgl *sgl);

View File

@ -146,7 +146,7 @@ struct proto_ops {
int (*socketpair)(struct socket *sock1, int (*socketpair)(struct socket *sock1,
struct socket *sock2); struct socket *sock2);
int (*accept) (struct socket *sock, int (*accept) (struct socket *sock,
struct socket *newsock, int flags); struct socket *newsock, int flags, bool kern);
int (*getname) (struct socket *sock, int (*getname) (struct socket *sock,
struct sockaddr *addr, struct sockaddr *addr,
int *sockaddr_len, int peer); int *sockaddr_len, int peer);

View File

@ -20,7 +20,8 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
int addr_len, int flags, int is_sendmsg); int addr_len, int flags, int is_sendmsg);
int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr, int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr,
int addr_len, int flags); int addr_len, int flags);
int inet_accept(struct socket *sock, struct socket *newsock, int flags); int inet_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern);
int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size); int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size);
ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset, ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset,
size_t size, int flags); size_t size, int flags);

View File

@ -258,7 +258,7 @@ inet_csk_rto_backoff(const struct inet_connection_sock *icsk,
return (unsigned long)min_t(u64, when, max_when); return (unsigned long)min_t(u64, when, max_when);
} }
struct sock *inet_csk_accept(struct sock *sk, int flags, int *err); struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern);
int inet_csk_get_port(struct sock *sk, unsigned short snum); int inet_csk_get_port(struct sock *sk, unsigned short snum);

View File

@ -476,7 +476,8 @@ struct sctp_pf {
int (*send_verify) (struct sctp_sock *, union sctp_addr *); int (*send_verify) (struct sctp_sock *, union sctp_addr *);
int (*supported_addrs)(const struct sctp_sock *, __be16 *); int (*supported_addrs)(const struct sctp_sock *, __be16 *);
struct sock *(*create_accept_sk) (struct sock *sk, struct sock *(*create_accept_sk) (struct sock *sk,
struct sctp_association *asoc); struct sctp_association *asoc,
bool kern);
int (*addr_to_user)(struct sctp_sock *sk, union sctp_addr *addr); int (*addr_to_user)(struct sctp_sock *sk, union sctp_addr *addr);
void (*to_sk_saddr)(union sctp_addr *, struct sock *sk); void (*to_sk_saddr)(union sctp_addr *, struct sock *sk);
void (*to_sk_daddr)(union sctp_addr *, struct sock *sk); void (*to_sk_daddr)(union sctp_addr *, struct sock *sk);

View File

@ -236,6 +236,7 @@ struct sock_common {
* @sk_shutdown: mask of %SEND_SHUTDOWN and/or %RCV_SHUTDOWN * @sk_shutdown: mask of %SEND_SHUTDOWN and/or %RCV_SHUTDOWN
* @sk_userlocks: %SO_SNDBUF and %SO_RCVBUF settings * @sk_userlocks: %SO_SNDBUF and %SO_RCVBUF settings
* @sk_lock: synchronizer * @sk_lock: synchronizer
* @sk_kern_sock: True if sock is using kernel lock classes
* @sk_rcvbuf: size of receive buffer in bytes * @sk_rcvbuf: size of receive buffer in bytes
* @sk_wq: sock wait queue and async head * @sk_wq: sock wait queue and async head
* @sk_rx_dst: receive input route used by early demux * @sk_rx_dst: receive input route used by early demux
@ -430,7 +431,8 @@ struct sock {
#endif #endif
kmemcheck_bitfield_begin(flags); kmemcheck_bitfield_begin(flags);
unsigned int sk_padding : 2, unsigned int sk_padding : 1,
sk_kern_sock : 1,
sk_no_check_tx : 1, sk_no_check_tx : 1,
sk_no_check_rx : 1, sk_no_check_rx : 1,
sk_userlocks : 4, sk_userlocks : 4,
@ -1015,7 +1017,8 @@ struct proto {
int addr_len); int addr_len);
int (*disconnect)(struct sock *sk, int flags); int (*disconnect)(struct sock *sk, int flags);
struct sock * (*accept)(struct sock *sk, int flags, int *err); struct sock * (*accept)(struct sock *sk, int flags, int *err,
bool kern);
int (*ioctl)(struct sock *sk, int cmd, int (*ioctl)(struct sock *sk, int cmd,
unsigned long arg); unsigned long arg);
@ -1573,7 +1576,7 @@ int sock_cmsg_send(struct sock *sk, struct msghdr *msg,
int sock_no_bind(struct socket *, struct sockaddr *, int); int sock_no_bind(struct socket *, struct sockaddr *, int);
int sock_no_connect(struct socket *, struct sockaddr *, int, int); int sock_no_connect(struct socket *, struct sockaddr *, int, int);
int sock_no_socketpair(struct socket *, struct socket *); int sock_no_socketpair(struct socket *, struct socket *);
int sock_no_accept(struct socket *, struct socket *, int); int sock_no_accept(struct socket *, struct socket *, int, bool);
int sock_no_getname(struct socket *, struct sockaddr *, int *, int); int sock_no_getname(struct socket *, struct sockaddr *, int *, int);
unsigned int sock_no_poll(struct file *, struct socket *, unsigned int sock_no_poll(struct file *, struct socket *,
struct poll_table_struct *); struct poll_table_struct *);

View File

@ -318,7 +318,8 @@ static int svc_listen(struct socket *sock, int backlog)
return error; return error;
} }
static int svc_accept(struct socket *sock, struct socket *newsock, int flags) static int svc_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct sk_buff *skb; struct sk_buff *skb;
@ -329,7 +330,7 @@ static int svc_accept(struct socket *sock, struct socket *newsock, int flags)
lock_sock(sk); lock_sock(sk);
error = svc_create(sock_net(sk), newsock, 0, 0); error = svc_create(sock_net(sk), newsock, 0, kern);
if (error) if (error)
goto out; goto out;

View File

@ -1320,7 +1320,8 @@ static int __must_check ax25_connect(struct socket *sock,
return err; return err;
} }
static int ax25_accept(struct socket *sock, struct socket *newsock, int flags) static int ax25_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sk_buff *skb; struct sk_buff *skb;
struct sock *newsk; struct sock *newsk;

View File

@ -301,7 +301,7 @@ static int l2cap_sock_listen(struct socket *sock, int backlog)
} }
static int l2cap_sock_accept(struct socket *sock, struct socket *newsock, static int l2cap_sock_accept(struct socket *sock, struct socket *newsock,
int flags) int flags, bool kern)
{ {
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
struct sock *sk = sock->sk, *nsk; struct sock *sk = sock->sk, *nsk;

View File

@ -471,7 +471,8 @@ static int rfcomm_sock_listen(struct socket *sock, int backlog)
return err; return err;
} }
static int rfcomm_sock_accept(struct socket *sock, struct socket *newsock, int flags) static int rfcomm_sock_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
struct sock *sk = sock->sk, *nsk; struct sock *sk = sock->sk, *nsk;

View File

@ -627,7 +627,7 @@ static int sco_sock_listen(struct socket *sock, int backlog)
} }
static int sco_sock_accept(struct socket *sock, struct socket *newsock, static int sco_sock_accept(struct socket *sock, struct socket *newsock,
int flags) int flags, bool kern)
{ {
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
struct sock *sk = sock->sk, *ch; struct sock *sk = sock->sk, *ch;

View File

@ -197,66 +197,55 @@ EXPORT_SYMBOL(sk_net_capable);
/* /*
* Each address family might have different locking rules, so we have * Each address family might have different locking rules, so we have
* one slock key per address family: * one slock key per address family and separate keys for internal and
* userspace sockets.
*/ */
static struct lock_class_key af_family_keys[AF_MAX]; static struct lock_class_key af_family_keys[AF_MAX];
static struct lock_class_key af_family_kern_keys[AF_MAX];
static struct lock_class_key af_family_slock_keys[AF_MAX]; static struct lock_class_key af_family_slock_keys[AF_MAX];
static struct lock_class_key af_family_kern_slock_keys[AF_MAX];
/* /*
* Make lock validator output more readable. (we pre-construct these * Make lock validator output more readable. (we pre-construct these
* strings build-time, so that runtime initialization of socket * strings build-time, so that runtime initialization of socket
* locks is fast): * locks is fast):
*/ */
#define _sock_locks(x) \
x "AF_UNSPEC", x "AF_UNIX" , x "AF_INET" , \
x "AF_AX25" , x "AF_IPX" , x "AF_APPLETALK", \
x "AF_NETROM", x "AF_BRIDGE" , x "AF_ATMPVC" , \
x "AF_X25" , x "AF_INET6" , x "AF_ROSE" , \
x "AF_DECnet", x "AF_NETBEUI" , x "AF_SECURITY" , \
x "AF_KEY" , x "AF_NETLINK" , x "AF_PACKET" , \
x "AF_ASH" , x "AF_ECONET" , x "AF_ATMSVC" , \
x "AF_RDS" , x "AF_SNA" , x "AF_IRDA" , \
x "AF_PPPOX" , x "AF_WANPIPE" , x "AF_LLC" , \
x "27" , x "28" , x "AF_CAN" , \
x "AF_TIPC" , x "AF_BLUETOOTH", x "IUCV" , \
x "AF_RXRPC" , x "AF_ISDN" , x "AF_PHONET" , \
x "AF_IEEE802154", x "AF_CAIF" , x "AF_ALG" , \
x "AF_NFC" , x "AF_VSOCK" , x "AF_KCM" , \
x "AF_QIPCRTR", x "AF_SMC" , x "AF_MAX"
static const char *const af_family_key_strings[AF_MAX+1] = { static const char *const af_family_key_strings[AF_MAX+1] = {
"sk_lock-AF_UNSPEC", "sk_lock-AF_UNIX" , "sk_lock-AF_INET" , _sock_locks("sk_lock-")
"sk_lock-AF_AX25" , "sk_lock-AF_IPX" , "sk_lock-AF_APPLETALK",
"sk_lock-AF_NETROM", "sk_lock-AF_BRIDGE" , "sk_lock-AF_ATMPVC" ,
"sk_lock-AF_X25" , "sk_lock-AF_INET6" , "sk_lock-AF_ROSE" ,
"sk_lock-AF_DECnet", "sk_lock-AF_NETBEUI" , "sk_lock-AF_SECURITY" ,
"sk_lock-AF_KEY" , "sk_lock-AF_NETLINK" , "sk_lock-AF_PACKET" ,
"sk_lock-AF_ASH" , "sk_lock-AF_ECONET" , "sk_lock-AF_ATMSVC" ,
"sk_lock-AF_RDS" , "sk_lock-AF_SNA" , "sk_lock-AF_IRDA" ,
"sk_lock-AF_PPPOX" , "sk_lock-AF_WANPIPE" , "sk_lock-AF_LLC" ,
"sk_lock-27" , "sk_lock-28" , "sk_lock-AF_CAN" ,
"sk_lock-AF_TIPC" , "sk_lock-AF_BLUETOOTH", "sk_lock-IUCV" ,
"sk_lock-AF_RXRPC" , "sk_lock-AF_ISDN" , "sk_lock-AF_PHONET" ,
"sk_lock-AF_IEEE802154", "sk_lock-AF_CAIF" , "sk_lock-AF_ALG" ,
"sk_lock-AF_NFC" , "sk_lock-AF_VSOCK" , "sk_lock-AF_KCM" ,
"sk_lock-AF_QIPCRTR", "sk_lock-AF_SMC" , "sk_lock-AF_MAX"
}; };
static const char *const af_family_slock_key_strings[AF_MAX+1] = { static const char *const af_family_slock_key_strings[AF_MAX+1] = {
"slock-AF_UNSPEC", "slock-AF_UNIX" , "slock-AF_INET" , _sock_locks("slock-")
"slock-AF_AX25" , "slock-AF_IPX" , "slock-AF_APPLETALK",
"slock-AF_NETROM", "slock-AF_BRIDGE" , "slock-AF_ATMPVC" ,
"slock-AF_X25" , "slock-AF_INET6" , "slock-AF_ROSE" ,
"slock-AF_DECnet", "slock-AF_NETBEUI" , "slock-AF_SECURITY" ,
"slock-AF_KEY" , "slock-AF_NETLINK" , "slock-AF_PACKET" ,
"slock-AF_ASH" , "slock-AF_ECONET" , "slock-AF_ATMSVC" ,
"slock-AF_RDS" , "slock-AF_SNA" , "slock-AF_IRDA" ,
"slock-AF_PPPOX" , "slock-AF_WANPIPE" , "slock-AF_LLC" ,
"slock-27" , "slock-28" , "slock-AF_CAN" ,
"slock-AF_TIPC" , "slock-AF_BLUETOOTH", "slock-AF_IUCV" ,
"slock-AF_RXRPC" , "slock-AF_ISDN" , "slock-AF_PHONET" ,
"slock-AF_IEEE802154", "slock-AF_CAIF" , "slock-AF_ALG" ,
"slock-AF_NFC" , "slock-AF_VSOCK" ,"slock-AF_KCM" ,
"slock-AF_QIPCRTR", "slock-AF_SMC" , "slock-AF_MAX"
}; };
static const char *const af_family_clock_key_strings[AF_MAX+1] = { static const char *const af_family_clock_key_strings[AF_MAX+1] = {
"clock-AF_UNSPEC", "clock-AF_UNIX" , "clock-AF_INET" , _sock_locks("clock-")
"clock-AF_AX25" , "clock-AF_IPX" , "clock-AF_APPLETALK", };
"clock-AF_NETROM", "clock-AF_BRIDGE" , "clock-AF_ATMPVC" ,
"clock-AF_X25" , "clock-AF_INET6" , "clock-AF_ROSE" , static const char *const af_family_kern_key_strings[AF_MAX+1] = {
"clock-AF_DECnet", "clock-AF_NETBEUI" , "clock-AF_SECURITY" , _sock_locks("k-sk_lock-")
"clock-AF_KEY" , "clock-AF_NETLINK" , "clock-AF_PACKET" , };
"clock-AF_ASH" , "clock-AF_ECONET" , "clock-AF_ATMSVC" , static const char *const af_family_kern_slock_key_strings[AF_MAX+1] = {
"clock-AF_RDS" , "clock-AF_SNA" , "clock-AF_IRDA" , _sock_locks("k-slock-")
"clock-AF_PPPOX" , "clock-AF_WANPIPE" , "clock-AF_LLC" , };
"clock-27" , "clock-28" , "clock-AF_CAN" , static const char *const af_family_kern_clock_key_strings[AF_MAX+1] = {
"clock-AF_TIPC" , "clock-AF_BLUETOOTH", "clock-AF_IUCV" , _sock_locks("k-clock-")
"clock-AF_RXRPC" , "clock-AF_ISDN" , "clock-AF_PHONET" ,
"clock-AF_IEEE802154", "clock-AF_CAIF" , "clock-AF_ALG" ,
"clock-AF_NFC" , "clock-AF_VSOCK" , "clock-AF_KCM" ,
"clock-AF_QIPCRTR", "clock-AF_SMC" , "clock-AF_MAX"
}; };
/* /*
@ -264,6 +253,7 @@ static const char *const af_family_clock_key_strings[AF_MAX+1] = {
* so split the lock classes by using a per-AF key: * so split the lock classes by using a per-AF key:
*/ */
static struct lock_class_key af_callback_keys[AF_MAX]; static struct lock_class_key af_callback_keys[AF_MAX];
static struct lock_class_key af_kern_callback_keys[AF_MAX];
/* Take into consideration the size of the struct sk_buff overhead in the /* Take into consideration the size of the struct sk_buff overhead in the
* determination of these values, since that is non-constant across * determination of these values, since that is non-constant across
@ -1293,7 +1283,16 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
*/ */
static inline void sock_lock_init(struct sock *sk) static inline void sock_lock_init(struct sock *sk)
{ {
sock_lock_init_class_and_name(sk, if (sk->sk_kern_sock)
sock_lock_init_class_and_name(
sk,
af_family_kern_slock_key_strings[sk->sk_family],
af_family_kern_slock_keys + sk->sk_family,
af_family_kern_key_strings[sk->sk_family],
af_family_kern_keys + sk->sk_family);
else
sock_lock_init_class_and_name(
sk,
af_family_slock_key_strings[sk->sk_family], af_family_slock_key_strings[sk->sk_family],
af_family_slock_keys + sk->sk_family, af_family_slock_keys + sk->sk_family,
af_family_key_strings[sk->sk_family], af_family_key_strings[sk->sk_family],
@ -1399,6 +1398,7 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
* why we need sk_prot_creator -acme * why we need sk_prot_creator -acme
*/ */
sk->sk_prot = sk->sk_prot_creator = prot; sk->sk_prot = sk->sk_prot_creator = prot;
sk->sk_kern_sock = kern;
sock_lock_init(sk); sock_lock_init(sk);
sk->sk_net_refcnt = kern ? 0 : 1; sk->sk_net_refcnt = kern ? 0 : 1;
if (likely(sk->sk_net_refcnt)) if (likely(sk->sk_net_refcnt))
@ -2277,7 +2277,8 @@ int sock_no_socketpair(struct socket *sock1, struct socket *sock2)
} }
EXPORT_SYMBOL(sock_no_socketpair); EXPORT_SYMBOL(sock_no_socketpair);
int sock_no_accept(struct socket *sock, struct socket *newsock, int flags) int sock_no_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
return -EOPNOTSUPP; return -EOPNOTSUPP;
} }
@ -2481,7 +2482,14 @@ void sock_init_data(struct socket *sock, struct sock *sk)
} }
rwlock_init(&sk->sk_callback_lock); rwlock_init(&sk->sk_callback_lock);
lockdep_set_class_and_name(&sk->sk_callback_lock, if (sk->sk_kern_sock)
lockdep_set_class_and_name(
&sk->sk_callback_lock,
af_kern_callback_keys + sk->sk_family,
af_family_kern_clock_key_strings[sk->sk_family]);
else
lockdep_set_class_and_name(
&sk->sk_callback_lock,
af_callback_keys + sk->sk_family, af_callback_keys + sk->sk_family,
af_family_clock_key_strings[sk->sk_family]); af_family_clock_key_strings[sk->sk_family]);

View File

@ -1070,7 +1070,8 @@ static struct sk_buff *dn_wait_for_connect(struct sock *sk, long *timeo)
return skb == NULL ? ERR_PTR(err) : skb; return skb == NULL ? ERR_PTR(err) : skb;
} }
static int dn_accept(struct socket *sock, struct socket *newsock, int flags) static int dn_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk, *newsk; struct sock *sk = sock->sk, *newsk;
struct sk_buff *skb = NULL; struct sk_buff *skb = NULL;
@ -1099,7 +1100,7 @@ static int dn_accept(struct socket *sock, struct socket *newsock, int flags)
cb = DN_SKB_CB(skb); cb = DN_SKB_CB(skb);
sk->sk_ack_backlog--; sk->sk_ack_backlog--;
newsk = dn_alloc_sock(sock_net(sk), newsock, sk->sk_allocation, 0); newsk = dn_alloc_sock(sock_net(sk), newsock, sk->sk_allocation, kern);
if (newsk == NULL) { if (newsk == NULL) {
release_sock(sk); release_sock(sk);
kfree_skb(skb); kfree_skb(skb);

View File

@ -689,11 +689,12 @@ EXPORT_SYMBOL(inet_stream_connect);
* Accept a pending connection. The TCP layer now gives BSD semantics. * Accept a pending connection. The TCP layer now gives BSD semantics.
*/ */
int inet_accept(struct socket *sock, struct socket *newsock, int flags) int inet_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk1 = sock->sk; struct sock *sk1 = sock->sk;
int err = -EINVAL; int err = -EINVAL;
struct sock *sk2 = sk1->sk_prot->accept(sk1, flags, &err); struct sock *sk2 = sk1->sk_prot->accept(sk1, flags, &err, kern);
if (!sk2) if (!sk2)
goto do_err; goto do_err;

View File

@ -424,7 +424,7 @@ static int inet_csk_wait_for_connect(struct sock *sk, long timeo)
/* /*
* This will accept the next outstanding connection. * This will accept the next outstanding connection.
*/ */
struct sock *inet_csk_accept(struct sock *sk, int flags, int *err) struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern)
{ {
struct inet_connection_sock *icsk = inet_csk(sk); struct inet_connection_sock *icsk = inet_csk(sk);
struct request_sock_queue *queue = &icsk->icsk_accept_queue; struct request_sock_queue *queue = &icsk->icsk_accept_queue;

View File

@ -828,7 +828,8 @@ static int irda_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
* Wait for incoming connection * Wait for incoming connection
* *
*/ */
static int irda_accept(struct socket *sock, struct socket *newsock, int flags) static int irda_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct irda_sock *new, *self = irda_sk(sk); struct irda_sock *new, *self = irda_sk(sk);
@ -836,7 +837,7 @@ static int irda_accept(struct socket *sock, struct socket *newsock, int flags)
struct sk_buff *skb = NULL; struct sk_buff *skb = NULL;
int err; int err;
err = irda_create(sock_net(sk), newsock, sk->sk_protocol, 0); err = irda_create(sock_net(sk), newsock, sk->sk_protocol, kern);
if (err) if (err)
return err; return err;

View File

@ -938,7 +938,7 @@ static int iucv_sock_listen(struct socket *sock, int backlog)
/* Accept a pending connection */ /* Accept a pending connection */
static int iucv_sock_accept(struct socket *sock, struct socket *newsock, static int iucv_sock_accept(struct socket *sock, struct socket *newsock,
int flags) int flags, bool kern)
{ {
DECLARE_WAITQUEUE(wait, current); DECLARE_WAITQUEUE(wait, current);
struct sock *sk = sock->sk, *nsk; struct sock *sk = sock->sk, *nsk;

View File

@ -641,11 +641,13 @@ static void llc_cmsg_rcv(struct msghdr *msg, struct sk_buff *skb)
* @sock: Socket which connections arrive on. * @sock: Socket which connections arrive on.
* @newsock: Socket to move incoming connection to. * @newsock: Socket to move incoming connection to.
* @flags: User specified operational flags. * @flags: User specified operational flags.
* @kern: If the socket is kernel internal
* *
* Accept a new incoming connection. * Accept a new incoming connection.
* Returns 0 upon success, negative otherwise. * Returns 0 upon success, negative otherwise.
*/ */
static int llc_ui_accept(struct socket *sock, struct socket *newsock, int flags) static int llc_ui_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk, *newsk; struct sock *sk = sock->sk, *newsk;
struct llc_sock *llc, *newllc; struct llc_sock *llc, *newllc;

View File

@ -765,7 +765,8 @@ static int nr_connect(struct socket *sock, struct sockaddr *uaddr,
return err; return err;
} }
static int nr_accept(struct socket *sock, struct socket *newsock, int flags) static int nr_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sk_buff *skb; struct sk_buff *skb;
struct sock *newsk; struct sock *newsk;

View File

@ -441,7 +441,7 @@ struct sock *nfc_llcp_accept_dequeue(struct sock *parent,
} }
static int llcp_sock_accept(struct socket *sock, struct socket *newsock, static int llcp_sock_accept(struct socket *sock, struct socket *newsock,
int flags) int flags, bool kern)
{ {
DECLARE_WAITQUEUE(wait, current); DECLARE_WAITQUEUE(wait, current);
struct sock *sk = sock->sk, *new_sk; struct sock *sk = sock->sk, *new_sk;

View File

@ -772,7 +772,8 @@ static void pep_sock_close(struct sock *sk, long timeout)
sock_put(sk); sock_put(sk);
} }
static struct sock *pep_sock_accept(struct sock *sk, int flags, int *errp) static struct sock *pep_sock_accept(struct sock *sk, int flags, int *errp,
bool kern)
{ {
struct pep_sock *pn = pep_sk(sk), *newpn; struct pep_sock *pn = pep_sk(sk), *newpn;
struct sock *newsk = NULL; struct sock *newsk = NULL;
@ -846,7 +847,8 @@ static struct sock *pep_sock_accept(struct sock *sk, int flags, int *errp)
} }
/* Create a new to-be-accepted sock */ /* Create a new to-be-accepted sock */
newsk = sk_alloc(sock_net(sk), PF_PHONET, GFP_KERNEL, sk->sk_prot, 0); newsk = sk_alloc(sock_net(sk), PF_PHONET, GFP_KERNEL, sk->sk_prot,
kern);
if (!newsk) { if (!newsk) {
pep_reject_conn(sk, skb, PN_PIPE_ERR_OVERLOAD, GFP_KERNEL); pep_reject_conn(sk, skb, PN_PIPE_ERR_OVERLOAD, GFP_KERNEL);
err = -ENOBUFS; err = -ENOBUFS;

View File

@ -305,7 +305,7 @@ static int pn_socket_connect(struct socket *sock, struct sockaddr *addr,
} }
static int pn_socket_accept(struct socket *sock, struct socket *newsock, static int pn_socket_accept(struct socket *sock, struct socket *newsock,
int flags) int flags, bool kern)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct sock *newsk; struct sock *newsk;
@ -314,7 +314,7 @@ static int pn_socket_accept(struct socket *sock, struct socket *newsock,
if (unlikely(sk->sk_state != TCP_LISTEN)) if (unlikely(sk->sk_state != TCP_LISTEN))
return -EINVAL; return -EINVAL;
newsk = sk->sk_prot->accept(sk, flags, &err); newsk = sk->sk_prot->accept(sk, flags, &err, kern);
if (!newsk) if (!newsk)
return err; return err;

View File

@ -133,7 +133,7 @@ int rds_tcp_accept_one(struct socket *sock)
new_sock->type = sock->type; new_sock->type = sock->type;
new_sock->ops = sock->ops; new_sock->ops = sock->ops;
ret = sock->ops->accept(sock, new_sock, O_NONBLOCK); ret = sock->ops->accept(sock, new_sock, O_NONBLOCK, true);
if (ret < 0) if (ret < 0)
goto out; goto out;

View File

@ -871,7 +871,8 @@ static int rose_connect(struct socket *sock, struct sockaddr *uaddr, int addr_le
return err; return err;
} }
static int rose_accept(struct socket *sock, struct socket *newsock, int flags) static int rose_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sk_buff *skb; struct sk_buff *skb;
struct sock *newsk; struct sock *newsk;

View File

@ -640,14 +640,15 @@ static sctp_scope_t sctp_v6_scope(union sctp_addr *addr)
/* Create and initialize a new sk for the socket to be returned by accept(). */ /* Create and initialize a new sk for the socket to be returned by accept(). */
static struct sock *sctp_v6_create_accept_sk(struct sock *sk, static struct sock *sctp_v6_create_accept_sk(struct sock *sk,
struct sctp_association *asoc) struct sctp_association *asoc,
bool kern)
{ {
struct sock *newsk; struct sock *newsk;
struct ipv6_pinfo *newnp, *np = inet6_sk(sk); struct ipv6_pinfo *newnp, *np = inet6_sk(sk);
struct sctp6_sock *newsctp6sk; struct sctp6_sock *newsctp6sk;
struct ipv6_txoptions *opt; struct ipv6_txoptions *opt;
newsk = sk_alloc(sock_net(sk), PF_INET6, GFP_KERNEL, sk->sk_prot, 0); newsk = sk_alloc(sock_net(sk), PF_INET6, GFP_KERNEL, sk->sk_prot, kern);
if (!newsk) if (!newsk)
goto out; goto out;

View File

@ -575,10 +575,11 @@ static int sctp_v4_is_ce(const struct sk_buff *skb)
/* Create and initialize a new sk for the socket returned by accept(). */ /* Create and initialize a new sk for the socket returned by accept(). */
static struct sock *sctp_v4_create_accept_sk(struct sock *sk, static struct sock *sctp_v4_create_accept_sk(struct sock *sk,
struct sctp_association *asoc) struct sctp_association *asoc,
bool kern)
{ {
struct sock *newsk = sk_alloc(sock_net(sk), PF_INET, GFP_KERNEL, struct sock *newsk = sk_alloc(sock_net(sk), PF_INET, GFP_KERNEL,
sk->sk_prot, 0); sk->sk_prot, kern);
struct inet_sock *newinet; struct inet_sock *newinet;
if (!newsk) if (!newsk)

View File

@ -4116,7 +4116,7 @@ static int sctp_disconnect(struct sock *sk, int flags)
* descriptor will be returned from accept() to represent the newly * descriptor will be returned from accept() to represent the newly
* formed association. * formed association.
*/ */
static struct sock *sctp_accept(struct sock *sk, int flags, int *err) static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern)
{ {
struct sctp_sock *sp; struct sctp_sock *sp;
struct sctp_endpoint *ep; struct sctp_endpoint *ep;
@ -4151,7 +4151,7 @@ static struct sock *sctp_accept(struct sock *sk, int flags, int *err)
*/ */
asoc = list_entry(ep->asocs.next, struct sctp_association, asocs); asoc = list_entry(ep->asocs.next, struct sctp_association, asocs);
newsk = sp->pf->create_accept_sk(sk, asoc); newsk = sp->pf->create_accept_sk(sk, asoc, kern);
if (!newsk) { if (!newsk) {
error = -ENOMEM; error = -ENOMEM;
goto out; goto out;

View File

@ -944,7 +944,7 @@ static int smc_listen(struct socket *sock, int backlog)
} }
static int smc_accept(struct socket *sock, struct socket *new_sock, static int smc_accept(struct socket *sock, struct socket *new_sock,
int flags) int flags, bool kern)
{ {
struct sock *sk = sock->sk, *nsk; struct sock *sk = sock->sk, *nsk;
DECLARE_WAITQUEUE(wait, current); DECLARE_WAITQUEUE(wait, current);

View File

@ -1506,7 +1506,7 @@ SYSCALL_DEFINE4(accept4, int, fd, struct sockaddr __user *, upeer_sockaddr,
if (err) if (err)
goto out_fd; goto out_fd;
err = sock->ops->accept(sock, newsock, sock->file->f_flags); err = sock->ops->accept(sock, newsock, sock->file->f_flags, false);
if (err < 0) if (err < 0)
goto out_fd; goto out_fd;
@ -3239,7 +3239,7 @@ int kernel_accept(struct socket *sock, struct socket **newsock, int flags)
if (err < 0) if (err < 0)
goto done; goto done;
err = sock->ops->accept(sock, *newsock, flags); err = sock->ops->accept(sock, *newsock, flags, true);
if (err < 0) { if (err < 0) {
sock_release(*newsock); sock_release(*newsock);
*newsock = NULL; *newsock = NULL;

View File

@ -115,7 +115,8 @@ static void tipc_data_ready(struct sock *sk);
static void tipc_write_space(struct sock *sk); static void tipc_write_space(struct sock *sk);
static void tipc_sock_destruct(struct sock *sk); static void tipc_sock_destruct(struct sock *sk);
static int tipc_release(struct socket *sock); static int tipc_release(struct socket *sock);
static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags); static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags,
bool kern);
static void tipc_sk_timeout(unsigned long data); static void tipc_sk_timeout(unsigned long data);
static int tipc_sk_publish(struct tipc_sock *tsk, uint scope, static int tipc_sk_publish(struct tipc_sock *tsk, uint scope,
struct tipc_name_seq const *seq); struct tipc_name_seq const *seq);
@ -2029,7 +2030,8 @@ static int tipc_wait_for_accept(struct socket *sock, long timeo)
* *
* Returns 0 on success, errno otherwise * Returns 0 on success, errno otherwise
*/ */
static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags) static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags,
bool kern)
{ {
struct sock *new_sk, *sk = sock->sk; struct sock *new_sk, *sk = sock->sk;
struct sk_buff *buf; struct sk_buff *buf;
@ -2051,7 +2053,7 @@ static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags)
buf = skb_peek(&sk->sk_receive_queue); buf = skb_peek(&sk->sk_receive_queue);
res = tipc_sk_create(sock_net(sock->sk), new_sock, 0, 0); res = tipc_sk_create(sock_net(sock->sk), new_sock, 0, kern);
if (res) if (res)
goto exit; goto exit;
security_sk_clone(sock->sk, new_sock->sk); security_sk_clone(sock->sk, new_sock->sk);

View File

@ -636,7 +636,7 @@ static int unix_bind(struct socket *, struct sockaddr *, int);
static int unix_stream_connect(struct socket *, struct sockaddr *, static int unix_stream_connect(struct socket *, struct sockaddr *,
int addr_len, int flags); int addr_len, int flags);
static int unix_socketpair(struct socket *, struct socket *); static int unix_socketpair(struct socket *, struct socket *);
static int unix_accept(struct socket *, struct socket *, int); static int unix_accept(struct socket *, struct socket *, int, bool);
static int unix_getname(struct socket *, struct sockaddr *, int *, int); static int unix_getname(struct socket *, struct sockaddr *, int *, int);
static unsigned int unix_poll(struct file *, struct socket *, poll_table *); static unsigned int unix_poll(struct file *, struct socket *, poll_table *);
static unsigned int unix_dgram_poll(struct file *, struct socket *, static unsigned int unix_dgram_poll(struct file *, struct socket *,
@ -1402,7 +1402,8 @@ static void unix_sock_inherit_flags(const struct socket *old,
set_bit(SOCK_PASSSEC, &new->flags); set_bit(SOCK_PASSSEC, &new->flags);
} }
static int unix_accept(struct socket *sock, struct socket *newsock, int flags) static int unix_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct sock *tsk; struct sock *tsk;

View File

@ -1250,7 +1250,8 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
return err; return err;
} }
static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *listener; struct sock *listener;
int err; int err;

View File

@ -852,7 +852,8 @@ static int x25_wait_for_data(struct sock *sk, long timeout)
return rc; return rc;
} }
static int x25_accept(struct socket *sock, struct socket *newsock, int flags) static int x25_accept(struct socket *sock, struct socket *newsock, int flags,
bool kern)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct sock *newsk; struct sock *newsk;