This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d54bd00

remove unused code

1 parent 292b4e6 commit d54bd00

File tree

2 files changed: +3 -215 lines changed

float8_experimental/float8_dynamic_linear.py

Lines changed: 3 additions & 51 deletions
@@ -62,10 +62,8 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, **super_kwargs):
         super().__init__(**super_kwargs)

-    def forward(self, x):
-        x_fp8 = cast_to_float8_e4m3fn(
-            x, self.forward_config, activation_scale=self.activation_scale
-        )
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
@@ -74,30 +72,6 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y

-    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
-        """This functions converts the weight to a Float8Tensor and sets its requires_grad to False.
-
-        Args:
-            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.
-
-        Note:
-            This function is typically called during inference to quantize the weight once since
-            the weight is not updated during inference.
-
-        """
-        assert not isinstance(
-            self.weight, Float8Tensor
-        ), "Weight has already been quantized, cannot quantize again."
-        scale = tensor_to_scale(self.weight, dtype)
-        quantized_weight = to_fp8_no_autograd(
-            self.weight,
-            scale,
-            dtype,
-            self.forward_config,
-        )
-        self.weight = nn.Parameter(quantized_weight)
-        self.weight.requires_grad = False
-
     @classmethod
     def create_meta_class(
         cls, in_features: int, out_features: int
@@ -122,55 +96,37 @@ def set_weight_and_bias(
         self.bias = bias
         return self

-    def set_quantization_scales(
-        self, pre_quantize_weight: bool, activation_scale: Optional[torch.Tensor] = None
-    ) -> "Float8DynamicLinear":
-        if pre_quantize_weight:
-            self.quantize_weight()
-
-        self.register_buffer("activation_scale", activation_scale)
-        return self
-
     @classmethod
     def from_float(
         cls,
         mod,
         emulate: bool = False,
-        pre_quantize_weight: bool = False,
-        activation_scale: Optional[torch.Tensor] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear

         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
-            pre_quantize_weight (bool): whether to quantize the weight statically, this is useful
-                for inference where weights are not updated.
-            activation_scale (torch.Tensor): The scale of the input to this linear module, used for
-                for inference when a statically known scale is available.
         """
         return (
             cls.create_meta_class(mod.in_features, mod.out_features)
             .set_mm_configs(emulate)
             .set_weight_and_bias(mod.weight, mod.bias)
-            .set_quantization_scales(pre_quantize_weight, activation_scale)
         )


 def cast_to_float8_e4m3fn(
     inpt_tensor: torch.Tensor,
     mm_config: ScaledMMConfig,
     reduce_amax: bool = False,
-    activation_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
     """Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.

     Args:
         inpt_tensor: The input tensor to be cast.
         mm_config: Configuration settings for the matrix multiplication
         reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
-        activation_scale: Optional tensor specifying the scale for activation. Default is None.

     Returns:
         Float8Tensor: The input tensor cast to Float8 (e4m3fn) format.
@@ -180,11 +136,7 @@ def cast_to_float8_e4m3fn(
     """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = (
-        activation_scale
-        if activation_scale is not None
-        else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    )
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
     return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)

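Note on the change above: with the activation_scale path removed, cast_to_float8_e4m3fn always derives the scale dynamically from the input tensor via tensor_to_scale. The snippet below is a minimal, self-contained sketch of that per-tensor dynamic scaling idea, not the library's implementation; the helper name dynamic_scale and the direct use of torch.float8_e4m3fn (available in recent PyTorch releases) are illustrative assumptions.

import torch

# Illustrative sketch (assumed helper, not part of float8_experimental):
# per-tensor dynamic scaling derives the scale from the tensor's own amax,
# which is conceptually what the retained tensor_to_scale(...) call does.
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3fn value

def dynamic_scale(t: torch.Tensor) -> torch.Tensor:
    amax = t.abs().max().to(torch.float32)
    return FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(16, 32)
scale = dynamic_scale(x)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # quantize with the dynamic scale
x_back = x_fp8.to(torch.float32) / scale      # dequantize to inspect rounding error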
test/test_inference_flows.py

Lines changed: 0 additions & 164 deletions
@@ -47,170 +47,6 @@ def reset_parameters(self):
             m.reset_parameters()


-class TestHPTrainToFP8:
-    def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
-        with torch.no_grad():
-            base_output = base_mlp(input_tensor)
-            transformed_output = quantized_mlp(input_tensor)
-
-        # Compute and check SQNR
-        sqnr = compute_error(base_output, transformed_output)
-        assert sqnr.item() > 20, f"SQNR is too low: {sqnr.item()} dB"
-
-    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not availabl or on non H100 machine",
-    )
-    def test_dynamic_fp8_mlp(self, compile_backend, dtype):
-        original_mlp = FeedForward().to("cuda", dtype=dtype)
-        original_mlp.reset_parameters()
-
-        dynamic_fp8_mlp = copy.deepcopy(original_mlp)
-        swap_linear_with_float8_linear(
-            dynamic_fp8_mlp,
-            Float8DynamicLinear,
-            from_float_kwargs={"pre_quantize_weight": True},
-        )
-
-        batch_size = 4
-        num_tokens = 1024
-        embedding_dim = 4096
-
-        input_tensor = torch.randn(
-            batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
-        )
-
-        # Compile the models
-        compiled_original_mlp = torch.compile(original_mlp, backend=compile_backend)
-        compiled_dynamic_fp8_mlp = torch.compile(
-            dynamic_fp8_mlp, backend=compile_backend
-        )
-
-        self.base_test_mlp_transform(
-            compiled_original_mlp, compiled_dynamic_fp8_mlp, input_tensor
-        )
-
-    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not availabl or on non H100 machine",
-    )
-    def test_static_fp8_mlp(self, compile_backend, dtype):
-        original_mlp = FeedForward().to("cuda", dtype=dtype)
-        original_mlp.reset_parameters()
-
-        static_fp8_mlp = copy.deepcopy(original_mlp)
-        swap_linear_with_float8_linear(
-            static_fp8_mlp,
-            Float8DynamicLinear,
-            from_float_kwargs={
-                "pre_quantize_weight": True,
-                "activation_scale": torch.tensor(
-                    [1.0], device="cuda", dtype=torch.float32
-                ),
-            },
-        )
-
-        batch_size = 4
-        num_tokens = 1024
-        embedding_dim = 4096
-
-        input_tensor = torch.randn(
-            batch_size, num_tokens, embedding_dim, device="cuda", dtype=dtype
-        )
-
-        # Compile the models
-        compiled_original_mlp = torch.compile(original_mlp, backend=compile_backend)
-        compiled_static_fp8_mlp = torch.compile(static_fp8_mlp, backend=compile_backend)
-
-        self.base_test_mlp_transform(
-            compiled_original_mlp, compiled_static_fp8_mlp, input_tensor
-        )
-
-
-class TestFP8TrainToFP8:
-    def train(self, model: nn.Module, dtype: torch.dtype):
-        model.train()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
-        criterion = nn.MSELoss()
-        target_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
-        for _ in range(10):
-            input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
-            optimizer.zero_grad()
-            output = model(input_tensor)
-            loss = criterion(output, target_tensor)
-            loss.backward()
-            optimizer.step()
-        model.eval()
-        return model
-
-    @pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
-    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
-        "CUDA not available or on non H100 machine",
-    )
-    def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
-        # Initialize FP8 model
-        fp8_mlp = FeedForward().to("cuda", dtype=torch.float32)
-        fp8_mlp.reset_parameters()
-        swap_linear_with_float8_linear(
-            fp8_mlp,
-            Float8DynamicLinear,
-        )
-
-        # Train the model
-        self.train(fp8_mlp, dtype)
-
-        # Generate input tensor and original out
-        input_tensor = torch.randn(4, 1024, 4096, device="cuda", dtype=dtype)
-        og_out = fp8_mlp(input_tensor)
-
-        # Save model state dict
-        buffer = io.BytesIO()
-        torch.save(fp8_mlp.state_dict(), buffer)
-
-        # Reset buffer position to the beginning
-        buffer.seek(0)
-
-        # Later on you load the model, will be w/ Float8DynamicLinear on meta device
-        with torch.device("meta"):
-            new_fp8_mlp = FeedForward().to(dtype=dtype)
-            swap_linear_with_float8_linear(
-                new_fp8_mlp,
-                Float8DynamicLinear,
-            )
-
-        # Load the actual data
-        new_fp8_mlp.load_state_dict(
-            torch.load(buffer, weights_only=True), strict=True, assign=True
-        )
-
-        # Dynamic Activations + Quantized Weights
-        def quantize_dynamic_linear(x: nn.Module):
-            if isinstance(x, Float8DynamicLinear):
-                x.set_quantization_scales(pre_quantize_weight=True)
-            return x
-
-        new_fp8_mlp.apply(quantize_dynamic_linear)
-
-        for module in new_fp8_mlp.modules():
-            if isinstance(module, Float8DynamicLinear):
-                assert isinstance(module.weight, Float8Tensor)
-                assert module.weight.requires_grad is False
-
-        new_out = new_fp8_mlp(input_tensor)
-
-        # Assert exact equality
-        assert torch.all(og_out == new_out).item()
-
-
-# WE ARE GOING TO KEEP 1 or the other BELOW IS THE SEPARATE MODULE WORKFLOW
-
-
 class TestHPTrainToFP8LinearInference:
     def base_test_mlp_transform(self, base_mlp, quantized_mlp, input_tensor):
         with torch.no_grad():
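With TestHPTrainToFP8 and TestFP8TrainToFP8 deleted, the remaining tests only cover the dynamic workflow. For orientation, here is a minimal usage sketch of that surviving path, assembled from the calls visible in this diff; the import paths and the toy model standing in for FeedForward are assumptions about the repository layout, and actual fp8 matmul requires fp8-capable hardware (the tests skip on non-H100 machines).

import copy

import torch
import torch.nn as nn

# Assumed import locations; adjust to the actual module layout of float8_experimental.
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy stand-in for the FeedForward module used by the tests.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).to("cuda")
fp8_model = copy.deepcopy(model)

# Only the dynamic-scaling conversion remains after this commit; the
# pre_quantize_weight / activation_scale from_float kwargs are gone.
swap_linear_with_float8_linear(fp8_model, Float8DynamicLinear)

x = torch.randn(4, 1024, 4096, device="cuda")
with torch.no_grad():
    y = fp8_model(x)  # activations are cast to e4m3fn with per-tensor dynamic scales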

0 commit comments