2015-11-13 22 views
15

Chcę filtrować DataFrame przy użyciu warunku związanego z długością kolumny, to pytanie może być bardzo proste, ale nie znalazłem żadnego powiązanego pytania w SO.Filtrowanie DataFrame przy użyciu długości kolumny

dokładniej, mam DataFrame tylko jeden Column które z ArrayType(StringType()) chcę filtrować DataFrame pomocą długość jak filterer, kręciłem fragment poniżej.

df = sqlContext.read.parquet("letters.parquet") 
df.show() 

# The output will be 
# +------------+ 
# |  tokens| 
# +------------+ 
# |[L, S, Y, S]| 
# |[L, V, I, S]| 
# |[I, A, N, A]| 
# |[I, L, S, A]| 
# |[E, N, N, Y]| 
# |[E, I, M, A]| 
# |[O, A, N, A]| 
# | [S, U, S]| 
# +------------+ 

# But I want only the entries with length 3 or less 
fdf = df.filter(len(df.tokens) <= 3) 
fdf.show() # But it says that the TypeError: object of type 'Column' has no len(), so the previous statement is obviously incorrect. 

Column's Documentation czytałem, ale nie znaleźliśmy żadnego majątku przydatna dla sprawy. Doceniam każdą pomoc!

Odpowiedz

29

W Spark> = 1,5 można użyć size funkcję:

from pyspark.sql.functions import col, size 

df = sqlContext.createDataFrame([ 
    (["L", "S", "Y", "S"], ), 
    (["L", "V", "I", "S"], ), 
    (["I", "A", "N", "A"], ), 
    (["I", "L", "S", "A"], ), 
    (["E", "N", "N", "Y"], ), 
    (["E", "I", "M", "A"], ), 
    (["O", "A", "N", "A"], ), 
    (["S", "U", "S"], )], 
    ("tokens",)) 

df.where(size(col("tokens")) <= 3).show() 

## +---------+ 
## | tokens| 
## +---------+ 
## |[S, U, S]| 
## +---------+ 

W Spark < 1,5 UDF powinno załatwić sprawę:

from pyspark.sql.types import IntegerType 
from pyspark.sql.functions import udf 

size_ = udf(lambda xs: len(xs), IntegerType()) 

df.where(size_(col("tokens")) <= 3).show() 

## +---------+ 
## | tokens| 
## +---------+ 
## |[S, U, S]| 
## +---------+ 

Jeśli używasz HiveContext następnie size UDF z surowego SQL powinien działać z dowolną wersją:

df.registerTempTable("df") 
sqlContext.sql("SELECT * FROM df WHERE size(tokens) <= 3").show() 

## +--------------------+ 
## |    tokens| 
## +--------------------+ 
## |ArrayBuffer(S, U, S)| 
## +--------------------+ 

przypadku kolumn smyczkowych można albo użyć udf określono powyżej lub length funkcja:

from pyspark.sql.functions import length 

df = sqlContext.createDataFrame([("fooo",), ("bar",)], ("k",)) 
df.where(length(col("k")) <= 3).show() 

## +---+ 
## | k| 
## +---+ 
## |bar| 
## +---+ 
+2

Co jeśli kolumna jest 'string' i udaję filtrować według długości' string' za? –

+3

Ta sama funkcja udf lub 'length'. – zero323

Powiązane problemy