2 votes

PySpark: compter sur une fenêtre avec réinitialisation

J'ai un DataFrame PySpark qui ressemble à ceci:

df = spark.createDataFrame(
    data=[
    (1, "ALLEMAGNE", "20230606", True),
    (2, "ALLEMAGNE", "20230620", False),
    (3, "ALLEMAGNE", "20230627", True),
    (4, "ALLEMAGNE", "20230705", True),
    (5, "ALLEMAGNE", "20230714", False),
    (6, "ALLEMAGNE", "20230715", True),
    ],
    schema=["ID", "PAYS", "DATE", "DRAPEAU"]
)
df.show()

+---+---------+--------+-----+
| ID|    PAYS|    DATE|DRAPEAU|
+---+---------+--------+-----+
|  1|ALLEMAGNE|20230606| true|
|  2|ALLEMAGNE|20230620|false|
|  3|ALLEMAGNE|20230627| true|
|  4|ALLEMAGNE|20230705| true|
|  5|ALLEMAGNE|20230714|false|
|  6|ALLEMAGNE|20230715| true|
+---+---------+--------+-----+

Le DataFrame a plus de pays. Je veux créer une nouvelle colonne COUNT_WITH_RESET en suivant la logique suivante:

  • Si DRAPEAU=False, alors COUNT_WITH_RESET=0.
  • Si DRAPEAU=True, alors COUNT_WITH_RESET doit compter le nombre de lignes à partir de la date précédente où DRAPEAU=False pour ce pays spécifique.

Voici le résultat pour l'exemple ci-dessus.

+---+---------+--------+-----+----------------+
| ID|    PAYS|    DATE|DRAPEAU|COUNT_WITH_RESET|
+---+---------+--------+-----+----------------+
|  1|ALLEMAGNE|20230606| true|               1|
|  2|ALLEMAGNE|20230620|false|               0|
|  3|ALLEMAGNE|20230627| true|               1|
|  4|ALLEMAGNE|20230705| true|               2|
|  5|ALLEMAGNE|20230714|false|               0|
|  6|ALLEMAGNE|20230715| true|               1|
+---+---------+--------+-----+----------------+

J'ai essayé avec row_number() sur une fenêtre mais je n'arrive pas à réinitialiser le comptage. J'ai également essayé avec .rowsBetween(Window.unboundedPreceding, Window.currentRow). Voici mon approche:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_reset = Window.partitionBy("PAYS").orderBy("DATE")

df_with_reset = (
    df
    .withColumn("COUNT_WITH_RESET", F.when(~F.col("DRAPEAU"), 0)
                .otherwise(F.row_number().over(window_reset)))
)

df_with_reset.show()

+---+---------+--------+-----+----------------+
| ID|    PAYS|    DATE|DRAPEAU|COUNT_WITH_RESET|
+---+---------+--------+-----+----------------+
|  1|ALLEMAGNE|20230606| true|               1|
|  2|ALLEMAGNE|20230620|false|               0|
|  3|ALLEMAGNE|20230627| true|               3|
|  4|ALLEMAGNE|20230705| true|               4|
|  5|ALLEMAGNE|20230714|false|               0|
|  6|ALLEMAGNE|20230715| true|               6|
+---+---------+--------+-----+----------------+

C'est clairement incorrect car ma fenêtre ne partitionne que par pays, mais suis-je sur la bonne voie? Y a-t-il une fonction intégrée spécifique dans PySpark pour réaliser cela? Ai-je besoin d'une UDF? Toute aide serait appréciée.

1voto

Shubham Sharma Points 39381

Partitionner le dataframe par COUNTRY puis calculer la somme cumulative sur la colonne inversée FLAG pour attribuer des numéros de groupe afin de distinguer entre différents blocs de lignes qui commencent par false

W1 = Window.partitionBy('COUNTRY').orderBy('DATE')
df1 = df.withColumn('blocks', F.sum((~F.col('FLAG')).cast('long')).over(W1))

df1.show()
# +---+-------+--------+-----+------+
# | ID|COUNTRY|    DATE| FLAG|blocks|
# +---+-------+--------+-----+------+
# |  1|GERMANY|20230606| true|     0|
# |  2|GERMANY|20230620|false|     1|
# |  3|GERMANY|20230627| true|     1|
# |  4|GERMANY|20230705| true|     1|
# |  5|GERMANY|20230714|false|     2|
# |  6|GERMANY|20230715| true|     2|
# +---+-------+--------+-----+------+

Partitionner le dataframe par COUNTRY avec blocs puis calculer le numéro de ligne sur la partition ordonnée pour créer un compteur séquentiel

W2 = Window.partitionBy('COUNTRY', 'blocks').orderBy('DATE')
df1 = df1.withColumn('COUNT_WITH_RESET', F.row_number().over(W2) - 1)

df1.show()
# +---+-------+--------+-----+------+----------------+
# | ID|COUNTRY|    DATE| FLAG|blocks|COUNT_WITH_RESET|
# +---+-------+--------+-----+------+----------------+
# |  1|GERMANY|20230606| true|     0|               0|
# |  2|GERMANY|20230620|false|     1|               0|
# |  3|GERMANY|20230627| true|     1|               1|
# |  4|GERMANY|20230705| true|     1|               2|
# |  5|GERMANY|20230714|false|     2|               0|
# |  6|GERMANY|20230715| true|     2|               1|
# +---+-------+--------+-----+------+----------------+

1voto

user238607 Points 683

J'ai modifié ma réponse d'ici https://stackoverflow.com/a/78056548/3238085 à ce problème.

import sys

from pyspark.sql import Window
from pyspark import SQLContext
from pyspark.sql.functions import *
import pyspark.sql.functions as F

spark = SparkContext('local')
sqlContext = SQLContext(spark)

sample_df = sqlContext.createDataFrame(
    data=[
        (1, "GERMANY", "20230606", True),
        (2, "GERMANY", "20230620", False),
        (3, "GERMANY", "20230627", True),
        (4, "GERMANY", "20230705", True),
        (5, "GERMANY", "20230714", False),
        (6, "GERMANY", "20230715", True),
    ],
    schema=["ID", "COUNTRY", "DATE", "FLAG"]
)

sample_df.show(100, truncate=False)

windowSpec = Window.partitionBy("COUNTRY").orderBy("id").rowsBetween(Window.unboundedPreceding, Window.currentRow)
sample_df = sample_df.withColumn('FLAGLIST', F.collect_list('FLAG').over(windowSpec))

initial_value = F.lit(0)

sample_df = sample_df.withColumn('COUNT_WITH_RESET', aggregate("FLAGLIST", initial_value,
                                                               lambda acc, x:  F.when( x == True,  acc + 1).otherwise(0)))

sample_df.show(truncate=False)

OUTPUT :

+---+-------+--------+-----+
|ID |COUNTRY|DATE    |FLAG |
+---+-------+--------+-----+
|1  |GERMANY|20230606|true |
|2  |GERMANY|20230620|false|
|3  |GERMANY|20230627|true |
|4  |GERMANY|20230705|true |
|5  |GERMANY|20230714|false|
|6  |GERMANY|20230715|true |
+---+-------+--------+-----+

+---+-------+--------+-----+--------------------------------------+----------------+
|ID |COUNTRY|DATE    |FLAG |FLAGLIST                              |COUNT_WITH_RESET|
+---+-------+--------+-----+--------------------------------------+----------------+
|1  |GERMANY|20230606|true |[true]                                |1               |
|2  |GERMANY|20230620|false|[true, false]                         |0               |
|3  |GERMANY|20230627|true |[true, false, true]                   |1               |
|4  |GERMANY|20230705|true |[true, false, true, true]             |2               |
|5  |GERMANY|20230714|false|[true, false, true, true, false]      |0               |
|6  |GERMANY|20230715|true |[true, false, true, true, false, true]|1               |
+---+-------+--------+-----+--------------------------------------+----------------+

Prograide.com

Prograide est une communauté de développeurs qui cherche à élargir la connaissance de la programmation au-delà de l'anglais.
Pour cela nous avons les plus grands doutes résolus en français et vous pouvez aussi poser vos propres questions ou résoudre celles des autres.

Powered by:

X