diff mbox series

[mt76,3/5] wifi: mt76: Move RCU section in mt7996_mcu_add_rate_ctrl_fixed()

Message ID 20250605-mt7996-sleep-while-atomic-v1-3-d46d15f9203c@kernel.org
State New
Headers show
Series [mt76,1/5] wifi: mt76: Assume __mt76_connac_mcu_alloc_sta_req runs in atomic context | expand

Commit Message

Lorenzo Bianconi June 5, 2025, 11:14 a.m. UTC
Since mt7996_mcu_set_fixed_field() can't be executed in a RCU critical
section, move RCU section in mt7996_mcu_add_rate_ctrl_fixed() and run
mt7996_mcu_set_fixed_field() in non-atomic context. This is a
preliminary patch to fix a 'sleep while atomic' issue in
mt7996_mac_sta_rc_work().

Fixes: 0762bdd30279 ("wifi: mt76: mt7996: rework mt7996_mac_sta_rc_work to support MLO")
Signed-off-by: Lorenzo Bianconi <lorenzo@kernel.org>
---
 drivers/net/wireless/mediatek/mt76/mt7996/mcu.c | 86 ++++++++++++++++---------
 1 file changed, 57 insertions(+), 29 deletions(-)
diff mbox series

Patch

diff --git a/drivers/net/wireless/mediatek/mt76/mt7996/mcu.c b/drivers/net/wireless/mediatek/mt76/mt7996/mcu.c
index 33c61e795b734e84af42fdea5ba33975e3e3365a..742497ba2a6bcd73e3660e626e4a756d79a467bf 100644
--- a/drivers/net/wireless/mediatek/mt76/mt7996/mcu.c
+++ b/drivers/net/wireless/mediatek/mt76/mt7996/mcu.c
@@ -1977,51 +1977,74 @@  int mt7996_mcu_set_fixed_field(struct mt7996_dev *dev, struct mt7996_sta *msta,
 }
 
 static int
-mt7996_mcu_add_rate_ctrl_fixed(struct mt7996_dev *dev,
-			       struct ieee80211_link_sta *link_sta,
-			       struct mt7996_vif_link *link,
-			       struct mt7996_sta_link *msta_link,
-			       u8 link_id)
+mt7996_mcu_add_rate_ctrl_fixed(struct mt7996_dev *dev, struct mt7996_sta *msta,
+			       struct ieee80211_vif *vif, u8 link_id)
 {
-	struct cfg80211_chan_def *chandef = &link->phy->mt76->chandef;
-	struct cfg80211_bitrate_mask *mask = &link->bitrate_mask;
-	enum nl80211_band band = chandef->chan->band;
-	struct mt7996_sta *msta = msta_link->sta;
+	struct ieee80211_link_sta *link_sta;
+	struct cfg80211_bitrate_mask mask;
+	struct mt7996_sta_link *msta_link;
+	struct mt7996_vif_link *link;
 	struct sta_phy_uni phy = {};
-	int ret, nrates = 0;
+	struct ieee80211_sta *sta;
+	int ret, nrates = 0, idx;
+	enum nl80211_band band;
+	bool has_he;
 
 #define __sta_phy_bitrate_mask_check(_mcs, _gi, _ht, _he)			\
 	do {									\
-		u8 i, gi = mask->control[band]._gi;				\
+		u8 i, gi = mask.control[band]._gi;				\
 		gi = (_he) ? gi : gi == NL80211_TXRATE_FORCE_SGI;		\
 		phy.sgi = gi;							\
-		phy.he_ltf = mask->control[band].he_ltf;			\
-		for (i = 0; i < ARRAY_SIZE(mask->control[band]._mcs); i++) {	\
-			if (!mask->control[band]._mcs[i])			\
+		phy.he_ltf = mask.control[band].he_ltf;				\
+		for (i = 0; i < ARRAY_SIZE(mask.control[band]._mcs); i++) {	\
+			if (!mask.control[band]._mcs[i])			\
 				continue;					\
-			nrates += hweight16(mask->control[band]._mcs[i]);	\
-			phy.mcs = ffs(mask->control[band]._mcs[i]) - 1;		\
+			nrates += hweight16(mask.control[band]._mcs[i]);	\
+			phy.mcs = ffs(mask.control[band]._mcs[i]) - 1;		\
 			if (_ht)						\
 				phy.mcs += 8 * i;				\
 		}								\
 	} while (0)
 
-	if (link_sta->he_cap.has_he) {
+	rcu_read_lock();
+
+	link = mt7996_vif_link(dev, vif, link_id);
+	if (!link)
+		goto error_unlock;
+
+	msta_link = rcu_dereference(msta->link[link_id]);
+	if (!msta_link)
+		goto error_unlock;
+
+	sta = wcid_to_sta(&msta_link->wcid);
+	link_sta = rcu_dereference(sta->link[link_id]);
+	if (!link_sta)
+		goto error_unlock;
+
+	band = link->phy->mt76->chandef.chan->band;
+	has_he = link_sta->he_cap.has_he;
+	mask = link->bitrate_mask;
+	idx = msta_link->wcid.idx;
+
+	if (has_he) {
 		__sta_phy_bitrate_mask_check(he_mcs, he_gi, 0, 1);
 	} else if (link_sta->vht_cap.vht_supported) {
 		__sta_phy_bitrate_mask_check(vht_mcs, gi, 0, 0);
 	} else if (link_sta->ht_cap.ht_supported) {
 		__sta_phy_bitrate_mask_check(ht_mcs, gi, 1, 0);
 	} else {
-		nrates = hweight32(mask->control[band].legacy);
-		phy.mcs = ffs(mask->control[band].legacy) - 1;
+		nrates = hweight32(mask.control[band].legacy);
+		phy.mcs = ffs(mask.control[band].legacy) - 1;
 	}
+
+	rcu_read_unlock();
+
 #undef __sta_phy_bitrate_mask_check
 
 	/* fall back to auto rate control */
-	if (mask->control[band].gi == NL80211_TXRATE_DEFAULT_GI &&
-	    mask->control[band].he_gi == GENMASK(7, 0) &&
-	    mask->control[band].he_ltf == GENMASK(7, 0) &&
+	if (mask.control[band].gi == NL80211_TXRATE_DEFAULT_GI &&
+	    mask.control[band].he_gi == GENMASK(7, 0) &&
+	    mask.control[band].he_ltf == GENMASK(7, 0) &&
 	    nrates != 1)
 		return 0;
 
@@ -2034,16 +2057,16 @@  mt7996_mcu_add_rate_ctrl_fixed(struct mt7996_dev *dev,
 	}
 
 	/* fixed GI */
-	if (mask->control[band].gi != NL80211_TXRATE_DEFAULT_GI ||
-	    mask->control[band].he_gi != GENMASK(7, 0)) {
+	if (mask.control[band].gi != NL80211_TXRATE_DEFAULT_GI ||
+	    mask.control[band].he_gi != GENMASK(7, 0)) {
 		u32 addr;
 
 		/* firmware updates only TXCMD but doesn't take WTBL into
 		 * account, so driver should update here to reflect the
 		 * actual txrate hardware sends out.
 		 */
-		addr = mt7996_mac_wtbl_lmac_addr(dev, msta_link->wcid.idx, 7);
-		if (link_sta->he_cap.has_he)
+		addr = mt7996_mac_wtbl_lmac_addr(dev, idx, 7);
+		if (has_he)
 			mt76_rmw_field(dev, addr, GENMASK(31, 24), phy.sgi);
 		else
 			mt76_rmw_field(dev, addr, GENMASK(15, 12), phy.sgi);
@@ -2055,7 +2078,7 @@  mt7996_mcu_add_rate_ctrl_fixed(struct mt7996_dev *dev,
 	}
 
 	/* fixed HE_LTF */
-	if (mask->control[band].he_ltf != GENMASK(7, 0)) {
+	if (mask.control[band].he_ltf != GENMASK(7, 0)) {
 		ret = mt7996_mcu_set_fixed_field(dev, msta, &phy, link_id,
 						 RATE_PARAM_FIXED_HE_LTF);
 		if (ret)
@@ -2063,6 +2086,11 @@  mt7996_mcu_add_rate_ctrl_fixed(struct mt7996_dev *dev,
 	}
 
 	return 0;
+
+error_unlock:
+	rcu_read_unlock();
+
+	return -ENODEV;
 }
 
 static void
@@ -2181,6 +2209,7 @@  int mt7996_mcu_add_rate_ctrl(struct mt7996_dev *dev,
 			     struct mt7996_sta_link *msta_link,
 			     u8 link_id, bool changed)
 {
+	struct mt7996_sta *msta = msta_link->sta;
 	struct sk_buff *skb;
 	int ret;
 
@@ -2207,8 +2236,7 @@  int mt7996_mcu_add_rate_ctrl(struct mt7996_dev *dev,
 	if (ret)
 		return ret;
 
-	return mt7996_mcu_add_rate_ctrl_fixed(dev, link_sta, link, msta_link,
-					      link_id);
+	return mt7996_mcu_add_rate_ctrl_fixed(dev, msta, vif, link_id);
 }
 
 static int