This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d1eae9a (1 parent: edae9a3)

updates to enable static weight quantization/dynamic activation quantization

File tree

2 files changed: +30, -4 lines

float8_experimental/float8_dynamic_linear.py

Lines changed: 21 additions & 1 deletion
```diff
@@ -72,14 +72,30 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
+    def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
+        """Used to perform static quantization; useful for inference where weights are not updated."""
+
+        scale = tensor_to_scale(self.weight, dtype)
+        quantized_weight = to_fp8_no_autograd(
+            self.weight,
+            scale,
+            dtype,
+            self.forward_config,
+        )
+        self.weight = nn.Parameter(quantized_weight)
+
     @classmethod
-    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
+    def from_float(
+        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+    ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            static_quantize_weight (bool): whether to quantize the weight statically; this is
+                useful for inference where weights are not updated.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -96,6 +112,10 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             )
         else:
             new_mod.weight = mod.weight
+
+        if static_quantize_weight:
+            new_mod.static_quantize_weight()
+
         new_mod.bias = mod.bias
         return new_mod
```
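For orientation, a minimal single-layer sketch of the new flag, not part of the commit: the import path is assumed from the file path above, and `emulate=True` is used on the assumption that the emulated fp8 matmul path is acceptable on non-H100 hardware.

```python
# Hypothetical single-layer example; the module path is assumed to match the
# file path float8_experimental/float8_dynamic_linear.py shown in this commit.
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

lin = nn.Linear(1024, 1024)

# static_quantize_weight=True makes from_float call new_mod.static_quantize_weight(),
# replacing the copied weight once with a float8_e4m3fn-quantized nn.Parameter
# instead of re-casting it on every forward pass.
fp8_lin = Float8DynamicLinear.from_float(
    lin, emulate=True, static_quantize_weight=True
)

# Activations are still cast dynamically inside forward(); only the weight is static.
print(type(fp8_lin.weight))
```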

float8_experimental/float8_linear_utils.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -6,7 +6,7 @@
 import copy
 import logging
 from enum import auto, Enum
-from typing import Callable, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
@@ -100,6 +100,7 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    from_float_kwargs: Dict[str, Any] = None,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -114,6 +115,9 @@ def swap_linear_with_float8_linear(
         linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers
             that pass the filter function will be swapped.
     """
+    if from_float_kwargs is None:
+        from_float_kwargs = {}
+
     module_names_to_skip = set(skip_fqn_list or [])
     if isinstance(module, nn.Linear) and (
         linear_layer_filter is None or linear_layer_filter(module)
@@ -122,7 +126,7 @@ def swap_linear_with_float8_linear(
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
            )
-        return module_cls.from_float(module, emulate=emulate)
+        return module_cls.from_float(module, emulate=emulate, **from_float_kwargs)
 
     # Mark all modules to skip as visited
     root_module = module
@@ -146,7 +150,9 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(module, emulate=emulate)
+            float8linear_module = module_cls.from_float(
+                module, emulate=emulate, **from_float_kwargs
+            )
             setattr(parent_module, module_name, float8linear_module)
 
     post_order_traversal(root_module, "", None)
```
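Below, a hedged end-to-end sketch of the `from_float_kwargs` plumbing, also not part of the commit: the import paths and the `(module, module_cls)` positional argument order are assumptions inferred from the file paths and the function body above.

```python
# Hypothetical model-level example; argument order and import paths are assumed.
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# from_float_kwargs is splatted (**) into every module_cls.from_float call during
# the post-order traversal, so each swapped layer receives static_quantize_weight=True.
model = swap_linear_with_float8_linear(
    model,
    Float8DynamicLinear,
    emulate=True,
    from_float_kwargs={"static_quantize_weight": True},
)

# Weights are now statically quantized; activations are quantized dynamically at
# inference time (assuming the emulated fp8 path runs on the available hardware).
model.eval()
with torch.no_grad():
    out = model(torch.randn(4, 512))
```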
