This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9bee8a3

updates to enable static weight quantization/dynamic activation quantization
1 parent edae9a3 · commit 9bee8a3
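
The commit pairs statically quantized fp8 weights with dynamically quantized activations. As a rough illustration only (the helper names and exact scaling used by float8_experimental, such as tensor_to_scale and to_fp8_no_autograd, may differ in detail), a per-tensor amax-based cast to float8 looks roughly like this:

# Rough sketch, not the library's exact implementation: static per-tensor fp8
# weight quantization alongside dynamic per-forward activation quantization.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def amax_to_scale(t: torch.Tensor) -> torch.Tensor:
    # Scale so the tensor's largest magnitude maps to the fp8 representable max.
    return FP8_MAX / t.abs().max().clamp(min=1e-12)

weight = torch.randn(4096, 4096)

# Static weight quantization: the scale is computed once, ahead of time, and the
# weight is stored in fp8; valid for inference because the weight no longer changes.
w_scale = amax_to_scale(weight)
w_fp8 = (weight * w_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

# Dynamic activation quantization: the scale is recomputed on every forward,
# since activation ranges vary from batch to batch.
x = torch.randn(16, 4096)
x_scale = amax_to_scale(x)
x_fp8 = (x * x_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)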

File tree

2 files changed: +27 −4 lines


float8_experimental/float8_dynamic_linear.py

Lines changed: 21 additions & 1 deletion
@@ -72,14 +72,30 @@ def forward(self, x):
         y = cast_to_float8_e5m2_bw(y, self.backward_config)
         return y
 
+    def static_quantize_weight(self, dtype: torch.dtype = torch.float8_e4m3fn) -> None:
+        """Used to perform static quantization; useful for inference where weights are not updated."""
+
+        scale = tensor_to_scale(self.weight, dtype)
+        quantized_weight = to_fp8_no_autograd(
+            self.weight,
+            scale,
+            dtype,
+            self.forward_config,
+        )
+        self.weight = nn.Parameter(quantized_weight)
+
     @classmethod
-    def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
+    def from_float(
+        cls, mod, emulate: bool = False, static_quantize_weight: bool = False
+    ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            static_quantize_weight (bool): whether to quantize the weight statically; this is useful
+                for inference where weights are not updated.
         """
         with torch.device("meta"):
             super_kwargs = {
@@ -96,6 +112,10 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             )
         else:
             new_mod.weight = mod.weight
+
+        if static_quantize_weight:
+            new_mod.static_quantize_weight()
+
         new_mod.bias = mod.bias
         return new_mod
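
For a single layer, the new flag can be exercised directly through from_float. A hedged usage sketch; the nn.Linear below and emulate=True are placeholder choices, not part of this commit:

import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

lin = nn.Linear(1024, 1024)
# Convert and statically quantize the weight at conversion time; activations
# are still cast to fp8 dynamically inside forward().
fp8_lin = Float8DynamicLinear.from_float(lin, emulate=True, static_quantize_weight=True)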

float8_experimental/float8_linear_utils.py

Lines changed: 6 additions & 3 deletions
@@ -6,7 +6,7 @@
 import copy
 import logging
 from enum import auto, Enum
-from typing import Callable, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
@@ -100,6 +100,7 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    from_float_kwargs: Dict[str, Any] = None,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -122,7 +123,7 @@ def swap_linear_with_float8_linear(
            raise AssertionError(
                f"Does not support a root nn.Linear with children: {module}"
            )
        return module_cls.from_float(module, emulate=emulate)
+        return module_cls.from_float(module, emulate=emulate, **from_float_kwargs)
 
     # Mark all modules to skip as visited
     root_module = module
@@ -146,7 +147,9 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            float8linear_module = module_cls.from_float(module, emulate=emulate)
+            float8linear_module = module_cls.from_float(
+                module, emulate=emulate, **from_float_kwargs
+            )
             setattr(parent_module, module_name, float8linear_module)
 
     post_order_traversal(root_module, "", None)
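
At the model level, the new kwargs dict is threaded through the swap utility. A hedged sketch: the toy model is a placeholder, and the utility is assumed to take the root module and the Float8 linear class (module_cls in the diff) as its first two arguments. Note that from_float_kwargs defaults to None in this change, so a caller must pass an explicit dict for the ** unpacking in from_float(...) to succeed.

import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# Swap every nn.Linear for Float8DynamicLinear and forward the new flag through
# from_float_kwargs so each converted layer statically quantizes its weight.
model = swap_linear_with_float8_linear(
    model,
    Float8DynamicLinear,
    emulate=True,
    from_float_kwargs={"static_quantize_weight": True},
)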
