Skip to content

REF: lreshape, wide_to_long #55976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 23 additions & 35 deletions pandas/core/reshape/melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pandas.core.dtypes.missing import notna

import pandas.core.algorithms as algos
from pandas.core.arrays import Categorical
from pandas.core.indexes.api import MultiIndex
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import tile_compat
Expand Down Expand Up @@ -139,7 +138,7 @@ def melt(
return result


def lreshape(data: DataFrame, groups, dropna: bool = True) -> DataFrame:
def lreshape(data: DataFrame, groups: dict, dropna: bool = True) -> DataFrame:
"""
Reshape wide-format data to long. Generalized inverse of DataFrame.pivot.

Expand Down Expand Up @@ -192,30 +191,20 @@ def lreshape(data: DataFrame, groups, dropna: bool = True) -> DataFrame:
2 Red Sox 2008 545
3 Yankees 2008 526
"""
if isinstance(groups, dict):
keys = list(groups.keys())
values = list(groups.values())
else:
keys, values = zip(*groups)

all_cols = list(set.union(*(set(x) for x in values)))
id_cols = list(data.columns.difference(all_cols))

K = len(values[0])

for seq in values:
if len(seq) != K:
raise ValueError("All column lists must be same length")

mdata = {}
pivot_cols = []

for target, names in zip(keys, values):
all_cols: set[Hashable] = set()
K = len(next(iter(groups.values())))
for target, names in groups.items():
if len(names) != K:
raise ValueError("All column lists must be same length")
to_concat = [data[col]._values for col in names]

mdata[target] = concat_compat(to_concat)
pivot_cols.append(target)
all_cols = all_cols.union(names)

id_cols = list(data.columns.difference(all_cols))
for col in id_cols:
mdata[col] = np.tile(data[col]._values, K)

Expand Down Expand Up @@ -467,10 +456,10 @@ def wide_to_long(
two 2.9
"""

def get_var_names(df, stub: str, sep: str, suffix: str) -> list[str]:
def get_var_names(df, stub: str, sep: str, suffix: str):
regex = rf"^{re.escape(stub)}{re.escape(sep)}{suffix}$"
pattern = re.compile(regex)
return [col for col in df.columns if pattern.match(col)]
return df.columns[df.columns.str.match(pattern)]

def melt_stub(df, stub: str, i, j, value_vars, sep: str):
newdf = melt(
Expand All @@ -480,7 +469,6 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
value_name=stub.rstrip(sep),
var_name=j,
)
newdf[j] = Categorical(newdf[j])
newdf[j] = newdf[j].str.replace(re.escape(stub + sep), "", regex=True)

# GH17627 Cast numerics suffixes to int/float
Expand All @@ -497,7 +485,7 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
else:
stubnames = list(stubnames)

if any(col in stubnames for col in df.columns):
if df.columns.isin(stubnames).any():
raise ValueError("stubname can't be identical to a column name")

if not is_list_like(i):
Expand All @@ -508,18 +496,18 @@ def melt_stub(df, stub: str, i, j, value_vars, sep: str):
if df[i].duplicated().any():
raise ValueError("the id variables need to uniquely identify each row")

value_vars = [get_var_names(df, stub, sep, suffix) for stub in stubnames]

value_vars_flattened = [e for sublist in value_vars for e in sublist]
id_vars = list(set(df.columns.tolist()).difference(value_vars_flattened))
_melted = []
value_vars_flattened = []
for stub in stubnames:
value_var = get_var_names(df, stub, sep, suffix)
value_vars_flattened.extend(value_var)
_melted.append(melt_stub(df, stub, i, j, value_var, sep))

_melted = [melt_stub(df, s, i, j, v, sep) for s, v in zip(stubnames, value_vars)]
melted = _melted[0].join(_melted[1:], how="outer")
melted = concat(_melted, axis=1)
id_vars = df.columns.difference(value_vars_flattened)
new = df[id_vars]

if len(i) == 1:
new = df[id_vars].set_index(i).join(melted)
return new

new = df[id_vars].merge(melted.reset_index(), on=i).set_index(i + [j])

return new
return new.set_index(i).join(melted)
else:
return new.merge(melted.reset_index(), on=i).set_index(i + [j])