@@ -902,14 +902,6 @@ static int chtls_skb_copy_to_page_nocache(struct sock *sk,
return 0;
}
-/* Read TLS header to find content type and data length */
-static int tls_header_read(struct tls_hdr *thdr, struct iov_iter *from)
-{
- if (copy_from_iter(thdr, sizeof(*thdr), from) != sizeof(*thdr))
- return -EFAULT;
- return (__force int)cpu_to_be16(thdr->length);
-}
-
static int csk_mem_free(struct chtls_dev *cdev, struct sock *sk)
{
return (cdev->max_host_sndbuf - sk->sk_wmem_queued);
@@ -1022,15 +1014,20 @@ int chtls_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
goto wait_for_sndbuf;
if (is_tls_tx(csk) && !csk->tlshws.txleft) {
- struct tls_hdr hdr;
+ unsigned char record_type = TLS_RECORD_TYPE_DATA;
- recordsz = tls_header_read(&hdr, &msg->msg_iter);
- size -= TLS_HEADER_LENGTH;
- copied += TLS_HEADER_LENGTH;
+ if (unlikely(msg->msg_controllen)) {
+ err = tls_proccess_cmsg(sk, msg, &record_type);
+ if (err)
+ goto out_err;
+ }
+
+ recordsz = size;
csk->tlshws.txleft = recordsz;
- csk->tlshws.type = hdr.type;
+ csk->tlshws.type = record_type;
+
if (skb)
- ULP_SKB_CB(skb)->ulp.tls.type = hdr.type;
+ ULP_SKB_CB(skb)->ulp.tls.type = record_type;
}
if (!skb || (ULP_SKB_CB(skb)->flags & ULPCB_FLAG_NO_APPEND) ||
@@ -1521,6 +1518,22 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
}
}
}
+ /* Set record type if not already done. For a non-data record,
+ * do not proceed if record type could not be copied.
+ */
+ if (ULP_SKB_CB(skb)->flags & ULPCB_FLAG_TLS_HDR) {
+ struct tls_hdr *thdr = (struct tls_hdr *)skb->data;
+ int cerr = 0;
+
+ cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+ sizeof(thdr->type), &thdr->type);
+
+ if (cerr && thdr->type != TLS_RECORD_TYPE_DATA)
+ return -EIO;
+ /* don't send tls header, skip copy */
+ goto skip_copy;
+ }
+
if (skb_copy_datagram_msg(skb, offset, msg, avail)) {
if (!copied) {
copied = -EFAULT;
@@ -195,6 +195,7 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
return rc;
}
+EXPORT_SYMBOL(tls_proccess_cmsg);
int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
int flags)