J'ai un DataFrame PySpark qui ressemble à ceci:
df = spark.createDataFrame(
data=[
(1, "ALLEMAGNE", "20230606", True),
(2, "ALLEMAGNE", "20230620", False),
(3, "ALLEMAGNE", "20230627", True),
(4, "ALLEMAGNE", "20230705", True),
(5, "ALLEMAGNE", "20230714", False),
(6, "ALLEMAGNE", "20230715", True),
],
schema=["ID", "PAYS", "DATE", "DRAPEAU"]
)
df.show()
+---+---------+--------+-----+
| ID| PAYS| DATE|DRAPEAU|
+---+---------+--------+-----+
| 1|ALLEMAGNE|20230606| true|
| 2|ALLEMAGNE|20230620|false|
| 3|ALLEMAGNE|20230627| true|
| 4|ALLEMAGNE|20230705| true|
| 5|ALLEMAGNE|20230714|false|
| 6|ALLEMAGNE|20230715| true|
+---+---------+--------+-----+
Le DataFrame a plus de pays. Je veux créer une nouvelle colonne COUNT_WITH_RESET
en suivant la logique suivante:
- Si
DRAPEAU=False
, alorsCOUNT_WITH_RESET=0
. - Si
DRAPEAU=True
, alorsCOUNT_WITH_RESET
doit compter le nombre de lignes à partir de la date précédente oùDRAPEAU=False
pour ce pays spécifique.
Voici le résultat pour l'exemple ci-dessus.
+---+---------+--------+-----+----------------+
| ID| PAYS| DATE|DRAPEAU|COUNT_WITH_RESET|
+---+---------+--------+-----+----------------+
| 1|ALLEMAGNE|20230606| true| 1|
| 2|ALLEMAGNE|20230620|false| 0|
| 3|ALLEMAGNE|20230627| true| 1|
| 4|ALLEMAGNE|20230705| true| 2|
| 5|ALLEMAGNE|20230714|false| 0|
| 6|ALLEMAGNE|20230715| true| 1|
+---+---------+--------+-----+----------------+
J'ai essayé avec row_number()
sur une fenêtre mais je n'arrive pas à réinitialiser le comptage. J'ai également essayé avec .rowsBetween(Window.unboundedPreceding, Window.currentRow)
. Voici mon approche:
from pyspark.sql.window import Window
import pyspark.sql.functions as F
window_reset = Window.partitionBy("PAYS").orderBy("DATE")
df_with_reset = (
df
.withColumn("COUNT_WITH_RESET", F.when(~F.col("DRAPEAU"), 0)
.otherwise(F.row_number().over(window_reset)))
)
df_with_reset.show()
+---+---------+--------+-----+----------------+
| ID| PAYS| DATE|DRAPEAU|COUNT_WITH_RESET|
+---+---------+--------+-----+----------------+
| 1|ALLEMAGNE|20230606| true| 1|
| 2|ALLEMAGNE|20230620|false| 0|
| 3|ALLEMAGNE|20230627| true| 3|
| 4|ALLEMAGNE|20230705| true| 4|
| 5|ALLEMAGNE|20230714|false| 0|
| 6|ALLEMAGNE|20230715| true| 6|
+---+---------+--------+-----+----------------+
C'est clairement incorrect car ma fenêtre ne partitionne que par pays, mais suis-je sur la bonne voie? Y a-t-il une fonction intégrée spécifique dans PySpark pour réaliser cela? Ai-je besoin d'une UDF? Toute aide serait appréciée.