diff mbox series

[1/3] kunit: remove va_format from kunit_assert

Message ID 20220125210011.3817742-2-dlatypov@google.com
State Accepted
Commit 6419abb80e82c603bbec6d7f5af6c2f79fa5c4ae
Headers show
Series kunit: further reduce stack usage of asserts | expand

Commit Message

Daniel Latypov Jan. 25, 2022, 9 p.m. UTC
The concern is that having a lot of redundant fields in kunit_assert can
blow up stack usage if the compiler doesn't optimize them away [1].

The comment on the `message` field (a struct va_format) implies that it
was meant to be initialized when the expect/assert was declared, but it
is only ever set when we run kunit_do_failed_assertion().

We don't need to access it outside of that function, so move it out of
the struct and make it a local variable there.

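This change also takes the chance to reduce the number of macros by
inlining the now-simplified KUNIT_INIT_ASSERT_STRUCT() macro and dropping
KUNIT_INIT_VA_FMT_NULL. It also makes kunit_assert_print_msg() static
(no longer exported), since its remaining callers are all in
lib/kunit/assert.c.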

[1] https://groups.google.com/g/kunit-dev/c/i3fZXgvBrfA/m/VULQg1z6BAAJ

Signed-off-by: Daniel Latypov <dlatypov@google.com>
---
 include/kunit/assert.h | 43 +++++++++++++-----------------------------
 lib/kunit/assert.c     | 27 +++++++++++++++-----------
 lib/kunit/test.c       | 12 +++++++-----
 3 files changed, 36 insertions(+), 46 deletions(-)

Comments

David Gow Jan. 27, 2022, 3:34 a.m. UTC | #1
On Wed, Jan 26, 2022 at 5:00 AM Daniel Latypov <dlatypov@google.com> wrote:
>
> The concern is that having a lot of redundant fields in kunit_assert can
> blow up stack usage if the compiler doesn't optimize them away [1].
>
> The comment on this field implies that it was meant to be initialized
> when the expect/assert was declared, but this only happens when we run
> kunit_do_failed_assertion().
>
> We don't need to access it outside of that function, so move it out of
> the struct and make it a local variable there.
>
> This change also takes the chance to reduce the number of macros by
> inlining the now simplified KUNIT_INIT_ASSERT_STRUCT() macro.
>
> [1] https://groups.google.com/g/kunit-dev/c/i3fZXgvBrfA/m/VULQg1z6BAAJ
>
> Signed-off-by: Daniel Latypov <dlatypov@google.com>
> ---

Looks good to me. I particularly like the removal of the
KUNIT_INIT_ASSERT_STRUCT() macro. I do feel that there's not much
point having a kunit_assert struct at all now that it's just one
function pointer, but the indirection is probably still useful enough
given that things are still changing, and function pointers are always
a little ugly.

Reviewed-by: David Gow <davidgow@google.com>

-- David

>  include/kunit/assert.h | 43 +++++++++++++-----------------------------
>  lib/kunit/assert.c     | 27 +++++++++++++++-----------
>  lib/kunit/test.c       | 12 +++++++-----
>  3 files changed, 36 insertions(+), 46 deletions(-)
>
> diff --git a/include/kunit/assert.h b/include/kunit/assert.h
> index f2b3ae5cc2de..0b3704db54b6 100644
> --- a/include/kunit/assert.h
> +++ b/include/kunit/assert.h
> @@ -42,44 +42,21 @@ struct kunit_loc {
>
>  /**
>   * struct kunit_assert - Data for printing a failed assertion or expectation.
> - * @message: an optional message to provide additional context.
>   * @format: a function which formats the data in this kunit_assert to a string.
>   *
>   * Represents a failed expectation/assertion. Contains all the data necessary to
>   * format a string to a user reporting the failure.
>   */
>  struct kunit_assert {
> -       struct va_format message;
>         void (*format)(const struct kunit_assert *assert,
> +                      const struct va_format *message,
>                        struct string_stream *stream);
>  };
>
> -/**
> - * KUNIT_INIT_VA_FMT_NULL - Default initializer for struct va_format.
> - *
> - * Used inside a struct initialization block to initialize struct va_format to
> - * default values where fmt and va are null.
> - */
> -#define KUNIT_INIT_VA_FMT_NULL { .fmt = NULL, .va = NULL }
> -
> -/**
> - * KUNIT_INIT_ASSERT_STRUCT() - Initializer for a &struct kunit_assert.
> - * @fmt: The formatting function which builds a string out of this kunit_assert.
> - *
> - * The base initializer for a &struct kunit_assert.
> - */
> -#define KUNIT_INIT_ASSERT_STRUCT(fmt) {                                               \
> -       .message = KUNIT_INIT_VA_FMT_NULL,                                     \
> -       .format = fmt                                                          \
> -}
> -
>  void kunit_assert_prologue(const struct kunit_loc *loc,
>                            enum kunit_assert_type type,
>                            struct string_stream *stream);
>
> -void kunit_assert_print_msg(const struct kunit_assert *assert,
> -                           struct string_stream *stream);
> -
>  /**
>   * struct kunit_fail_assert - Represents a plain fail expectation/assertion.
>   * @assert: The parent of this type.
> @@ -91,6 +68,7 @@ struct kunit_fail_assert {
>  };
>
>  void kunit_fail_assert_format(const struct kunit_assert *assert,
> +                             const struct va_format *message,
>                               struct string_stream *stream);
>
>  /**
> @@ -100,7 +78,7 @@ void kunit_fail_assert_format(const struct kunit_assert *assert,
>   * KUNIT_EXPECT_* and KUNIT_ASSERT_* macros.
>   */
>  #define KUNIT_INIT_FAIL_ASSERT_STRUCT {                                        \
> -       .assert = KUNIT_INIT_ASSERT_STRUCT(kunit_fail_assert_format)    \
> +       .assert = { .format = kunit_fail_assert_format },               \
>  }
>
>  /**
> @@ -120,6 +98,7 @@ struct kunit_unary_assert {
>  };
>
>  void kunit_unary_assert_format(const struct kunit_assert *assert,
> +                              const struct va_format *message,
>                                struct string_stream *stream);
>
>  /**
> @@ -131,7 +110,7 @@ void kunit_unary_assert_format(const struct kunit_assert *assert,
>   * KUNIT_EXPECT_* and KUNIT_ASSERT_* macros.
>   */
>  #define KUNIT_INIT_UNARY_ASSERT_STRUCT(cond, expect_true) {                   \
> -       .assert = KUNIT_INIT_ASSERT_STRUCT(kunit_unary_assert_format),         \
> +       .assert = { .format = kunit_unary_assert_format },                     \
>         .condition = cond,                                                     \
>         .expected_true = expect_true                                           \
>  }
> @@ -153,6 +132,7 @@ struct kunit_ptr_not_err_assert {
>  };
>
>  void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
> +                                    const struct va_format *message,
>                                      struct string_stream *stream);
>
>  /**
> @@ -165,7 +145,7 @@ void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
>   * KUNIT_EXPECT_* and KUNIT_ASSERT_* macros.
>   */
>  #define KUNIT_INIT_PTR_NOT_ERR_STRUCT(txt, val) {                             \
> -       .assert = KUNIT_INIT_ASSERT_STRUCT(kunit_ptr_not_err_assert_format),   \
> +       .assert = { .format = kunit_ptr_not_err_assert_format },               \
>         .text = txt,                                                           \
>         .value = val                                                           \
>  }
> @@ -194,6 +174,7 @@ struct kunit_binary_assert {
>  };
>
>  void kunit_binary_assert_format(const struct kunit_assert *assert,
> +                               const struct va_format *message,
>                                 struct string_stream *stream);
>
>  /**
> @@ -213,7 +194,7 @@ void kunit_binary_assert_format(const struct kunit_assert *assert,
>                                         left_val,                              \
>                                         right_str,                             \
>                                         right_val) {                           \
> -       .assert = KUNIT_INIT_ASSERT_STRUCT(kunit_binary_assert_format),        \
> +       .assert = { .format = kunit_binary_assert_format },                    \
>         .operation = op_str,                                                   \
>         .left_text = left_str,                                                 \
>         .left_value = left_val,                                                \
> @@ -245,6 +226,7 @@ struct kunit_binary_ptr_assert {
>  };
>
>  void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
> +                                   const struct va_format *message,
>                                     struct string_stream *stream);
>
>  /**
> @@ -265,7 +247,7 @@ void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
>                                             left_val,                          \
>                                             right_str,                         \
>                                             right_val) {                       \
> -       .assert = KUNIT_INIT_ASSERT_STRUCT(kunit_binary_ptr_assert_format),    \
> +       .assert = { .format = kunit_binary_ptr_assert_format },                \
>         .operation = op_str,                                                   \
>         .left_text = left_str,                                                 \
>         .left_value = left_val,                                                \
> @@ -297,6 +279,7 @@ struct kunit_binary_str_assert {
>  };
>
>  void kunit_binary_str_assert_format(const struct kunit_assert *assert,
> +                                   const struct va_format *message,
>                                     struct string_stream *stream);
>
>  /**
> @@ -316,7 +299,7 @@ void kunit_binary_str_assert_format(const struct kunit_assert *assert,
>                                             left_val,                          \
>                                             right_str,                         \
>                                             right_val) {                       \
> -       .assert = KUNIT_INIT_ASSERT_STRUCT(kunit_binary_str_assert_format),    \
> +       .assert = { .format = kunit_binary_str_assert_format },                \
>         .operation = op_str,                                                   \
>         .left_text = left_str,                                                 \
>         .left_value = left_val,                                                \
> diff --git a/lib/kunit/assert.c b/lib/kunit/assert.c
> index 9f4492a8e24e..c9c7ee0dfafa 100644
> --- a/lib/kunit/assert.c
> +++ b/lib/kunit/assert.c
> @@ -30,22 +30,23 @@ void kunit_assert_prologue(const struct kunit_loc *loc,
>  }
>  EXPORT_SYMBOL_GPL(kunit_assert_prologue);
>
> -void kunit_assert_print_msg(const struct kunit_assert *assert,
> -                           struct string_stream *stream)
> +static void kunit_assert_print_msg(const struct va_format *message,
> +                                  struct string_stream *stream)
>  {
> -       if (assert->message.fmt)
> -               string_stream_add(stream, "\n%pV", &assert->message);
> +       if (message->fmt)
> +               string_stream_add(stream, "\n%pV", message);
>  }
> -EXPORT_SYMBOL_GPL(kunit_assert_print_msg);
>
>  void kunit_fail_assert_format(const struct kunit_assert *assert,
> +                             const struct va_format *message,
>                               struct string_stream *stream)
>  {
> -       string_stream_add(stream, "%pV", &assert->message);
> +       string_stream_add(stream, "%pV", message);
>  }
>  EXPORT_SYMBOL_GPL(kunit_fail_assert_format);
>
>  void kunit_unary_assert_format(const struct kunit_assert *assert,
> +                              const struct va_format *message,
>                                struct string_stream *stream)
>  {
>         struct kunit_unary_assert *unary_assert;
> @@ -60,11 +61,12 @@ void kunit_unary_assert_format(const struct kunit_assert *assert,
>                 string_stream_add(stream,
>                                   KUNIT_SUBTEST_INDENT "Expected %s to be false, but is true\n",
>                                   unary_assert->condition);
> -       kunit_assert_print_msg(assert, stream);
> +       kunit_assert_print_msg(message, stream);
>  }
>  EXPORT_SYMBOL_GPL(kunit_unary_assert_format);
>
>  void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
> +                                    const struct va_format *message,
>                                      struct string_stream *stream)
>  {
>         struct kunit_ptr_not_err_assert *ptr_assert;
> @@ -82,7 +84,7 @@ void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
>                                   ptr_assert->text,
>                                   PTR_ERR(ptr_assert->value));
>         }
> -       kunit_assert_print_msg(assert, stream);
> +       kunit_assert_print_msg(message, stream);
>  }
>  EXPORT_SYMBOL_GPL(kunit_ptr_not_err_assert_format);
>
> @@ -110,6 +112,7 @@ static bool is_literal(struct kunit *test, const char *text, long long value,
>  }
>
>  void kunit_binary_assert_format(const struct kunit_assert *assert,
> +                               const struct va_format *message,
>                                 struct string_stream *stream)
>  {
>         struct kunit_binary_assert *binary_assert;
> @@ -132,11 +135,12 @@ void kunit_binary_assert_format(const struct kunit_assert *assert,
>                 string_stream_add(stream, KUNIT_SUBSUBTEST_INDENT "%s == %lld",
>                                   binary_assert->right_text,
>                                   binary_assert->right_value);
> -       kunit_assert_print_msg(assert, stream);
> +       kunit_assert_print_msg(message, stream);
>  }
>  EXPORT_SYMBOL_GPL(kunit_binary_assert_format);
>
>  void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
> +                                   const struct va_format *message,
>                                     struct string_stream *stream)
>  {
>         struct kunit_binary_ptr_assert *binary_assert;
> @@ -155,7 +159,7 @@ void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
>         string_stream_add(stream, KUNIT_SUBSUBTEST_INDENT "%s == %px",
>                           binary_assert->right_text,
>                           binary_assert->right_value);
> -       kunit_assert_print_msg(assert, stream);
> +       kunit_assert_print_msg(message, stream);
>  }
>  EXPORT_SYMBOL_GPL(kunit_binary_ptr_assert_format);
>
> @@ -176,6 +180,7 @@ static bool is_str_literal(const char *text, const char *value)
>  }
>
>  void kunit_binary_str_assert_format(const struct kunit_assert *assert,
> +                                   const struct va_format *message,
>                                     struct string_stream *stream)
>  {
>         struct kunit_binary_str_assert *binary_assert;
> @@ -196,6 +201,6 @@ void kunit_binary_str_assert_format(const struct kunit_assert *assert,
>                 string_stream_add(stream, KUNIT_SUBSUBTEST_INDENT "%s == \"%s\"",
>                                   binary_assert->right_text,
>                                   binary_assert->right_value);
> -       kunit_assert_print_msg(assert, stream);
> +       kunit_assert_print_msg(message, stream);
>  }
>  EXPORT_SYMBOL_GPL(kunit_binary_str_assert_format);
> diff --git a/lib/kunit/test.c b/lib/kunit/test.c
> index 7dec3248562f..3bca3bf5c15b 100644
> --- a/lib/kunit/test.c
> +++ b/lib/kunit/test.c
> @@ -241,7 +241,8 @@ static void kunit_print_string_stream(struct kunit *test,
>  }
>
>  static void kunit_fail(struct kunit *test, const struct kunit_loc *loc,
> -                      enum kunit_assert_type type, struct kunit_assert *assert)
> +                      enum kunit_assert_type type, struct kunit_assert *assert,
> +                      const struct va_format *message)
>  {
>         struct string_stream *stream;
>
> @@ -257,7 +258,7 @@ static void kunit_fail(struct kunit *test, const struct kunit_loc *loc,
>         }
>
>         kunit_assert_prologue(loc, type, stream);
> -       assert->format(assert, stream);
> +       assert->format(assert, message, stream);
>
>         kunit_print_string_stream(test, stream);
>
> @@ -284,12 +285,13 @@ void kunit_do_failed_assertion(struct kunit *test,
>                                const char *fmt, ...)
>  {
>         va_list args;
> +       struct va_format message;
>         va_start(args, fmt);
>
> -       assert->message.fmt = fmt;
> -       assert->message.va = &args;
> +       message.fmt = fmt;
> +       message.va = &args;
>
> -       kunit_fail(test, loc, type, assert);
> +       kunit_fail(test, loc, type, assert, &message);
>
>         va_end(args);
>
> --
> 2.35.0.rc2.247.g8bbb082509-goog
>
Brendan Higgins Jan. 27, 2022, 9:22 p.m. UTC | #2
On Tue, Jan 25, 2022 at 4:00 PM Daniel Latypov <dlatypov@google.com> wrote:
>
> The concern is that having a lot of redundant fields in kunit_assert can
> blow up stack usage if the compiler doesn't optimize them away [1].
>
> The comment on this field implies that it was meant to be initialized
> when the expect/assert was declared, but this only happens when we run
> kunit_do_failed_assertion().
>
> We don't need to access it outside of that function, so move it out of
> the struct and make it a local variable there.
>
> This change also takes the chance to reduce the number of macros by
> inlining the now simplified KUNIT_INIT_ASSERT_STRUCT() macro.
>
> [1] https://groups.google.com/g/kunit-dev/c/i3fZXgvBrfA/m/VULQg1z6BAAJ
>
> Signed-off-by: Daniel Latypov <dlatypov@google.com>

Reviewed-by: Brendan Higgins <brendanhiggins@google.com>
diff mbox series

Patch

diff --git a/include/kunit/assert.h b/include/kunit/assert.h
index f2b3ae5cc2de..0b3704db54b6 100644
--- a/include/kunit/assert.h
+++ b/include/kunit/assert.h
@@ -42,44 +42,21 @@  struct kunit_loc {
 
 /**
  * struct kunit_assert - Data for printing a failed assertion or expectation.
- * @message: an optional message to provide additional context.
  * @format: a function which formats the data in this kunit_assert to a string.
  *
  * Represents a failed expectation/assertion. Contains all the data necessary to
  * format a string to a user reporting the failure.
  */
 struct kunit_assert {
-	struct va_format message;
 	void (*format)(const struct kunit_assert *assert,
+		       const struct va_format *message,
 		       struct string_stream *stream);
 };
 
-/**
- * KUNIT_INIT_VA_FMT_NULL - Default initializer for struct va_format.
- *
- * Used inside a struct initialization block to initialize struct va_format to
- * default values where fmt and va are null.
- */
-#define KUNIT_INIT_VA_FMT_NULL { .fmt = NULL, .va = NULL }
-
-/**
- * KUNIT_INIT_ASSERT_STRUCT() - Initializer for a &struct kunit_assert.
- * @fmt: The formatting function which builds a string out of this kunit_assert.
- *
- * The base initializer for a &struct kunit_assert.
- */
-#define KUNIT_INIT_ASSERT_STRUCT(fmt) {					       \
-	.message = KUNIT_INIT_VA_FMT_NULL,				       \
-	.format = fmt							       \
-}
-
 void kunit_assert_prologue(const struct kunit_loc *loc,
 			   enum kunit_assert_type type,
 			   struct string_stream *stream);
 
-void kunit_assert_print_msg(const struct kunit_assert *assert,
-			    struct string_stream *stream);
-
 /**
  * struct kunit_fail_assert - Represents a plain fail expectation/assertion.
  * @assert: The parent of this type.
@@ -91,6 +68,7 @@  struct kunit_fail_assert {
 };
 
 void kunit_fail_assert_format(const struct kunit_assert *assert,
+			      const struct va_format *message,
 			      struct string_stream *stream);
 
 /**
@@ -100,7 +78,7 @@  void kunit_fail_assert_format(const struct kunit_assert *assert,
  * KUNIT_EXPECT_* and KUNIT_ASSERT_* macros.
  */
 #define KUNIT_INIT_FAIL_ASSERT_STRUCT {					\
-	.assert = KUNIT_INIT_ASSERT_STRUCT(kunit_fail_assert_format)	\
+	.assert = { .format = kunit_fail_assert_format },		\
 }
 
 /**
@@ -120,6 +98,7 @@  struct kunit_unary_assert {
 };
 
 void kunit_unary_assert_format(const struct kunit_assert *assert,
+			       const struct va_format *message,
 			       struct string_stream *stream);
 
 /**
@@ -131,7 +110,7 @@  void kunit_unary_assert_format(const struct kunit_assert *assert,
  * KUNIT_EXPECT_* and KUNIT_ASSERT_* macros.
  */
 #define KUNIT_INIT_UNARY_ASSERT_STRUCT(cond, expect_true) {		       \
-	.assert = KUNIT_INIT_ASSERT_STRUCT(kunit_unary_assert_format),	       \
+	.assert = { .format = kunit_unary_assert_format },		       \
 	.condition = cond,						       \
 	.expected_true = expect_true					       \
 }
@@ -153,6 +132,7 @@  struct kunit_ptr_not_err_assert {
 };
 
 void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
+				     const struct va_format *message,
 				     struct string_stream *stream);
 
 /**
@@ -165,7 +145,7 @@  void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
  * KUNIT_EXPECT_* and KUNIT_ASSERT_* macros.
  */
 #define KUNIT_INIT_PTR_NOT_ERR_STRUCT(txt, val) {			       \
-	.assert = KUNIT_INIT_ASSERT_STRUCT(kunit_ptr_not_err_assert_format),   \
+	.assert = { .format = kunit_ptr_not_err_assert_format },	       \
 	.text = txt,							       \
 	.value = val							       \
 }
@@ -194,6 +174,7 @@  struct kunit_binary_assert {
 };
 
 void kunit_binary_assert_format(const struct kunit_assert *assert,
+				const struct va_format *message,
 				struct string_stream *stream);
 
 /**
@@ -213,7 +194,7 @@  void kunit_binary_assert_format(const struct kunit_assert *assert,
 					left_val,			       \
 					right_str,			       \
 					right_val) {			       \
-	.assert = KUNIT_INIT_ASSERT_STRUCT(kunit_binary_assert_format),	       \
+	.assert = { .format = kunit_binary_assert_format },		       \
 	.operation = op_str,						       \
 	.left_text = left_str,						       \
 	.left_value = left_val,						       \
@@ -245,6 +226,7 @@  struct kunit_binary_ptr_assert {
 };
 
 void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
+				    const struct va_format *message,
 				    struct string_stream *stream);
 
 /**
@@ -265,7 +247,7 @@  void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
 					    left_val,			       \
 					    right_str,			       \
 					    right_val) {		       \
-	.assert = KUNIT_INIT_ASSERT_STRUCT(kunit_binary_ptr_assert_format),    \
+	.assert = { .format = kunit_binary_ptr_assert_format },		       \
 	.operation = op_str,						       \
 	.left_text = left_str,						       \
 	.left_value = left_val,						       \
@@ -297,6 +279,7 @@  struct kunit_binary_str_assert {
 };
 
 void kunit_binary_str_assert_format(const struct kunit_assert *assert,
+				    const struct va_format *message,
 				    struct string_stream *stream);
 
 /**
@@ -316,7 +299,7 @@  void kunit_binary_str_assert_format(const struct kunit_assert *assert,
 					    left_val,			       \
 					    right_str,			       \
 					    right_val) {		       \
-	.assert = KUNIT_INIT_ASSERT_STRUCT(kunit_binary_str_assert_format),    \
+	.assert = { .format = kunit_binary_str_assert_format },		       \
 	.operation = op_str,						       \
 	.left_text = left_str,						       \
 	.left_value = left_val,						       \
diff --git a/lib/kunit/assert.c b/lib/kunit/assert.c
index 9f4492a8e24e..c9c7ee0dfafa 100644
--- a/lib/kunit/assert.c
+++ b/lib/kunit/assert.c
@@ -30,22 +30,23 @@  void kunit_assert_prologue(const struct kunit_loc *loc,
 }
 EXPORT_SYMBOL_GPL(kunit_assert_prologue);
 
-void kunit_assert_print_msg(const struct kunit_assert *assert,
-			    struct string_stream *stream)
+static void kunit_assert_print_msg(const struct va_format *message,
+				   struct string_stream *stream)
 {
-	if (assert->message.fmt)
-		string_stream_add(stream, "\n%pV", &assert->message);
+	if (message->fmt)
+		string_stream_add(stream, "\n%pV", message);
 }
-EXPORT_SYMBOL_GPL(kunit_assert_print_msg);
 
 void kunit_fail_assert_format(const struct kunit_assert *assert,
+			      const struct va_format *message,
 			      struct string_stream *stream)
 {
-	string_stream_add(stream, "%pV", &assert->message);
+	string_stream_add(stream, "%pV", message);
 }
 EXPORT_SYMBOL_GPL(kunit_fail_assert_format);
 
 void kunit_unary_assert_format(const struct kunit_assert *assert,
+			       const struct va_format *message,
 			       struct string_stream *stream)
 {
 	struct kunit_unary_assert *unary_assert;
@@ -60,11 +61,12 @@  void kunit_unary_assert_format(const struct kunit_assert *assert,
 		string_stream_add(stream,
 				  KUNIT_SUBTEST_INDENT "Expected %s to be false, but is true\n",
 				  unary_assert->condition);
-	kunit_assert_print_msg(assert, stream);
+	kunit_assert_print_msg(message, stream);
 }
 EXPORT_SYMBOL_GPL(kunit_unary_assert_format);
 
 void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
+				     const struct va_format *message,
 				     struct string_stream *stream)
 {
 	struct kunit_ptr_not_err_assert *ptr_assert;
@@ -82,7 +84,7 @@  void kunit_ptr_not_err_assert_format(const struct kunit_assert *assert,
 				  ptr_assert->text,
 				  PTR_ERR(ptr_assert->value));
 	}
-	kunit_assert_print_msg(assert, stream);
+	kunit_assert_print_msg(message, stream);
 }
 EXPORT_SYMBOL_GPL(kunit_ptr_not_err_assert_format);
 
@@ -110,6 +112,7 @@  static bool is_literal(struct kunit *test, const char *text, long long value,
 }
 
 void kunit_binary_assert_format(const struct kunit_assert *assert,
+				const struct va_format *message,
 				struct string_stream *stream)
 {
 	struct kunit_binary_assert *binary_assert;
@@ -132,11 +135,12 @@  void kunit_binary_assert_format(const struct kunit_assert *assert,
 		string_stream_add(stream, KUNIT_SUBSUBTEST_INDENT "%s == %lld",
 				  binary_assert->right_text,
 				  binary_assert->right_value);
-	kunit_assert_print_msg(assert, stream);
+	kunit_assert_print_msg(message, stream);
 }
 EXPORT_SYMBOL_GPL(kunit_binary_assert_format);
 
 void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
+				    const struct va_format *message,
 				    struct string_stream *stream)
 {
 	struct kunit_binary_ptr_assert *binary_assert;
@@ -155,7 +159,7 @@  void kunit_binary_ptr_assert_format(const struct kunit_assert *assert,
 	string_stream_add(stream, KUNIT_SUBSUBTEST_INDENT "%s == %px",
 			  binary_assert->right_text,
 			  binary_assert->right_value);
-	kunit_assert_print_msg(assert, stream);
+	kunit_assert_print_msg(message, stream);
 }
 EXPORT_SYMBOL_GPL(kunit_binary_ptr_assert_format);
 
@@ -176,6 +180,7 @@  static bool is_str_literal(const char *text, const char *value)
 }
 
 void kunit_binary_str_assert_format(const struct kunit_assert *assert,
+				    const struct va_format *message,
 				    struct string_stream *stream)
 {
 	struct kunit_binary_str_assert *binary_assert;
@@ -196,6 +201,6 @@  void kunit_binary_str_assert_format(const struct kunit_assert *assert,
 		string_stream_add(stream, KUNIT_SUBSUBTEST_INDENT "%s == \"%s\"",
 				  binary_assert->right_text,
 				  binary_assert->right_value);
-	kunit_assert_print_msg(assert, stream);
+	kunit_assert_print_msg(message, stream);
 }
 EXPORT_SYMBOL_GPL(kunit_binary_str_assert_format);
diff --git a/lib/kunit/test.c b/lib/kunit/test.c
index 7dec3248562f..3bca3bf5c15b 100644
--- a/lib/kunit/test.c
+++ b/lib/kunit/test.c
@@ -241,7 +241,8 @@  static void kunit_print_string_stream(struct kunit *test,
 }
 
 static void kunit_fail(struct kunit *test, const struct kunit_loc *loc,
-		       enum kunit_assert_type type, struct kunit_assert *assert)
+		       enum kunit_assert_type type, struct kunit_assert *assert,
+		       const struct va_format *message)
 {
 	struct string_stream *stream;
 
@@ -257,7 +258,7 @@  static void kunit_fail(struct kunit *test, const struct kunit_loc *loc,
 	}
 
 	kunit_assert_prologue(loc, type, stream);
-	assert->format(assert, stream);
+	assert->format(assert, message, stream);
 
 	kunit_print_string_stream(test, stream);
 
@@ -284,12 +285,13 @@  void kunit_do_failed_assertion(struct kunit *test,
 			       const char *fmt, ...)
 {
 	va_list args;
+	struct va_format message;
 	va_start(args, fmt);
 
-	assert->message.fmt = fmt;
-	assert->message.va = &args;
+	message.fmt = fmt;
+	message.va = &args;
 
-	kunit_fail(test, loc, type, assert);
+	kunit_fail(test, loc, type, assert, &message);
 
 	va_end(args);