
Fuse matrix multiplication + SiLU #5413


Draft
wants to merge 2 commits into master

Conversation

@JohannesGaessler (Collaborator)

Currently I think there is a lot of potential for performance improvements from fusing tensors together since this would reduce the amount of I/O needed. I'm interested in working on this, but I want to determine the implementation details in advance in order to minimize the time spent on refactoring. I think there are two different approaches: either the user code explicitly opts in by building the graph with some fused kernel, or the graph is optimized afterwards by replacing certain tensor sequences with fused kernels. In this PR I tried a simple graph optimization implementation for fusing matrix multiplication and SiLU, but even that is already relatively complicated because the graph size changes afterwards and branching has to be handled (branching is not considered in this PR). It's possible that I'm simply ignorant of the best way to implement graph optimization, but so far my impression is that an explicit opt-in would be preferable. @ggerganov @slaren your feedback would be appreciated.
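
To make the comparison concrete, a minimal sketch of the opt-in style from the perspective of the user code; ggml_mul_mat_silu is a hypothetical builder used only for illustration and does not exist in ggml, and the variable names are illustrative:

// current: two ops, the intermediate result is written to and read back from memory
cur = ggml_silu(ctx, ggml_mul_mat(ctx, w, cur));

// explicit opt-in: the user code requests a single fused op up front
// (hypothetical helper, not part of the ggml API)
cur = ggml_mul_mat_silu(ctx, w, cur);

With the graph optimization approach the user code would keep building the first form, and a later pass would rewrite matching node sequences into the fused op.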

@slaren (Member) commented Feb 8, 2024

Some notes:

  • When fusing multiple ops, it is important to check that there are no other ops that depend on the output of the first ops. Currently, the graph nodes have a list of parents (i.e. the tensor src), but not of children, so this needs to be generated before removing nodes from the graph. ggml-alloc also needs a count of the children of each node, which it generates before doing anything else. It would be good to do this in a common pass to avoid repeating the same task multiple times (a minimal counting sketch follows below).
  • Modifying the graph is better done on the graph itself, before the list of nodes is generated (which is just one of many possible topological sorts of the graph). But I am not sure that there is a good way to do this at the moment in ggml, because the graphs are intimately linked to their topological sort as a means to avoid issues when dependencies are not properly represented in the graph (e.g. the calls to ggml_build_forward_expand after the KV update).
  • This should be done in cooperation with the backends. The backends should not be required to implement fused operations for basic operation: doing so would significantly increase the complexity of adding new backends, and it is simply not practical to have to update every backend every time something like this is changed. There is already an exception to this in ggml_soft_max_ext, which may have been OK at the time when we only had to update 3 backends, but now we have several more, and in time even more will be added; this kind of backwards-breaking change to the ops is untenable.

So I think that addressing this properly will require a lot of changes to ggml. For now, it may be ok to hack this in the graph_compute of each backend.
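
A minimal sketch of the common counting pass mentioned in the first point, written against the existing ggml_cgraph/ggml_tensor structures; count_children is a hypothetical helper, and a real implementation would use a hash table (as ggml-alloc does) instead of the quadratic search:

static void count_children(const struct ggml_cgraph * cgraph, int * n_children) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        n_children[i] = 0;
    }
    for (int i = 0; i < cgraph->n_nodes; i++) {
        for (int s = 0; s < GGML_MAX_SRC; s++) {
            const struct ggml_tensor * src = cgraph->nodes[i]->src[s];
            if (src == NULL) {
                continue;
            }
            // find the producer among the earlier nodes and bump its child count;
            // sources that are leafs (weights, inputs) are not counted here
            for (int j = 0; j < i; j++) {
                if (cgraph->nodes[j] == src) {
                    n_children[j]++;
                    break;
                }
            }
        }
    }
}

A fusion pass would then only merge a producer into its consumer when the producer's child count is exactly 1.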

@ggerganov (Member)

Currently I think there is a lot of potential for performance improvements from fusing tensors together since this would reduce the amount of I/O needed.

The most significant gain from fusing would be in the attention. There is already a PR with the necessary framework for fusing setup: #5021

There is already an exception to this in ggml_soft_max_ext

Is there any other way of fusing ops, other than the way we used for ggml_soft_max_ext? I think there is no way around adding explicit fused implementations in all backends. Technically, ggml_soft_max is now obsolete and should be removed since ggml_soft_max_ext is a superset of the functionality.

@slaren (Member) commented Feb 9, 2024

Is there any other way of fusing ops, other than the way we used for ggml_soft_max_ext?

The alternative would be to implement fusing automatically in the backends without changing the ops. In that way, implementing fused ops would be an optional optimization for the backends. At the time ggml_soft_max_ext was added, that was the simplest way to do it, but as more backends are added, changing the behaviour of the ops becomes a lot more expensive and having a better mechanism to do it becomes more important. I think it would be good to have a backend-specific graph optimization step, where optimizations such as fusing ops, or re-implementing ops in terms of different ops (like converting convolutions to im2col) are implemented, but that would require a lot of changes.

@JohannesGaessler (Collaborator, Author)

The most significant gain from fusing would be in the attention.

I agree, but Steward (FSSRepo) is already working on this and I think onboarding more devs is good for the project long-term. So if possible I would prefer to let him handle it and gain experience; if the performance is suboptimal I can still take a look afterwards.

Other than FlashAttention there should also be performance gains from fusing branching matrix multiplications (in LLaMA): first, fusing the 3 pre-attention KQV matrix multiplications together with the following RoPE into a single operation; second, fusing the up and gate matrix multiplications together with the SiLU and the elementwise multiplication.
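
For reference, a rough sketch of the unfused up/gate/SiLU pattern as it is built with the existing ggml API (variable names are illustrative, not taken from llama.cpp):

// feed-forward block of a LLaMA-style model, unfused
struct ggml_tensor * up   = ggml_mul_mat(ctx, ffn_up_w,   cur);  // [n_ff, n_tokens]
struct ggml_tensor * gate = ggml_mul_mat(ctx, ffn_gate_w, cur);  // [n_ff, n_tokens]
gate = ggml_silu(ctx, gate);                                     // SiLU activation
cur  = ggml_mul(ctx, gate, up);                                  // elementwise gating
cur  = ggml_mul_mat(ctx, ffn_down_w, cur);                       // project back to n_embd

The up and gate matrix multiplications read the same activations, so fusing them saves re-reading those activations, and fusing the SiLU and the multiplication saves a round trip of the intermediate tensors through memory.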

There is already a PR with the necessary framework for fusing setup: #5021

Are we talking about the same thing? I want to work out the best way of handling fused operations going forward, in particular how to handle them across different backends. Unless I'm misunderstanding something, in that PR there is currently a simple define to enable/disable FlashAttention, which (I assume) would not work for e.g. Vulkan since there is no implementation.

I think there is no way around adding explicit fused implementations in all backends.

You could provide more information during the build phase, e.g. by extending ggml_context with information about which backend the graph is being built for. For backends without the fused operations you could then replace them with a series of non-fused operations. But this would make the ggml graphs non-portable.

Or, as slaren said, you could add a backend-specific optimization step prior to execution where non-fused operations are replaced with their fused variants if possible.
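
A rough sketch of what such a pass could look like inside a backend's graph_compute, assuming child counts (n_children) have been computed as in the earlier sketch; ggml_cuda_op_mul_mat_silu is hypothetical and the surrounding loop is simplified:

for (int i = 0; i < cgraph->n_nodes; i++) {
    struct ggml_tensor * node = cgraph->nodes[i];
    struct ggml_tensor * next = i + 1 < cgraph->n_nodes ? cgraph->nodes[i + 1] : NULL;

    // fuse MUL_MAT + SILU when the SILU directly consumes the mat mul result
    // and nothing else reads the intermediate tensor
    if (node->op == GGML_OP_MUL_MAT && next != NULL &&
        next->op == GGML_OP_UNARY && ggml_get_unary_op(next) == GGML_UNARY_OP_SILU &&
        next->src[0] == node && n_children[i] == 1) {
        ggml_cuda_op_mul_mat_silu(ctx, node->src[0], node->src[1], next); // hypothetical fused kernel
        i++; // the SILU output has already been produced by the fused kernel
    } else {
        ggml_cuda_compute_forward(ctx, node);
    }
}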

but as more backends are added, changing the behaviour of the ops becomes a lot more expensive and having a better mechanism to do it becomes more important.

Are there plans to allow the combination of different backends, and if so, what is the ETA? I plan to implement fused operations for CPU and CUDA. For now we could maybe make it so that fused operations are only used instead of their non-fused counterparts if none of the other backends are used. Specifically, I would add #ifdef checks to the build code so that fused operations are disabled if e.g. GGML_USE_VULKAN is defined. That should guarantee that all backends continue to work even without any changes, but at the cost that the other backends do not receive any benefit from these changes to the CPU code.
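
Roughly like this in the graph build code (the exact set of GGML_USE_* defines is illustrative, and ggml_mul_mat_silu is the same hypothetical builder as above):

#if !defined(GGML_USE_VULKAN) && !defined(GGML_USE_SYCL) && !defined(GGML_USE_KOMPUTE)
    cur = ggml_mul_mat_silu(ctx, w, cur);            // fused path, implemented for CPU/CUDA
#else
    cur = ggml_silu(ctx, ggml_mul_mat(ctx, w, cur)); // portable unfused path
#endif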

@ggerganov (Member)

Yes, slaren's suggestion for automatic backend-specific fusing should be the long-term goal. #5021 can be viewed as another exception due to the importance of having faster and more memory-efficient attention. Currently, it is a compile-time define because it's simpler this way, but the FA op could be used based on runtime parameters, depending on the backends used.

I plan to implement fused operations for CPU and CUDA.

I'm afraid the gains from this might not justify the increase in technical debt. I would recommend attempting this at a later stage, likely when we have some variant of the graph optimization step mentioned earlier.

@JohannesGaessler (Collaborator, Author)

I'm afraid the gains from this might not justify the increase in technical debt. I would recommend attempting this at a later stage, likely when we have some variant of the graph optimization step mentioned earlier.

No worries, I'm willing to implement this. But since this adds a certain level of complexity I wanted to first discuss possible alternatives.

@slaren (Member) commented Feb 9, 2024

You could provide more information during the build phase, e.g. by extending ggml_context with information about which backend the graph is being built for.

This wouldn't work with ggml_backend_sched, since it is not known what backend will be used for a specific op until the graph is split.

Are there plans to allow the combination of different backends, and if so, what is the ETA?

It is something that I would like to add after the logic for automatic offloading of large matrix multiplications is moved to ggml_backend_sched and the remaining hooks in ggml.c of the CUDA and Vulkan backends are removed, but it is not a priority.

@mofosyne added the labels "Review Complexity: Medium", "refactoring", and "performance" on May 10, 2024
@JohannesGaessler (Collaborator, Author)

@mofosyne was the merging of master into this branch automatic or did you do it manually? In either case, please don't do this.

@mofosyne (Collaborator)

It was done manually, as the conflict appeared to be minor and I would like to make it easier to keep PRs going.
But alright, I'll avoid doing so, especially for draft PRs.

(Or, as a general rule, should it be left to the PR authors, since they would generally have the full context?)

@JohannesGaessler (Collaborator, Author)

I personally don't want other people to make any changes to one of my PRs without prior coordination or a good reason. I use my own llama.cpp fork on GitHub to share code between my machines, so changes that I am unaware of can potentially cause me to lose a lot of time debugging.

Contributor

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 551 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8546.41ms p(95)=19914.29ms fails=, finish reason: stop=487 truncated=64
  • Prompt processing (pp): avg=102.51tk/s p(95)=475.74tk/s
  • Token generation (tg): avg=32.45tk/s p(95)=47.73tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=activation-fusion commit=18d452a863a325e75ff5cf96dad43af6fdc21338

[Benchmark charts omitted: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing for llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 551 iterations.]

@slaren (Member) commented Jun 26, 2024

I think that the simplest way to support this would be to forget about trying to make the operation fusion automatic and just add a bunch of fused ops to ggml, with helper functions that can decompose a fused op into its individual ops for cases where the backend does not implement the fused op. Then backends could implement support like this:

// node was created with ggml_mul_mat_bias(ctx, a, b, c) = ggml_add(ggml_mul_mat(a, b), c)

case GGML_OP_FUSED:
    if (ggml_fused_n_ops(node) == 2 && 
        ggml_fused_op(node, 0) == GGML_OP_MUL_MAT &&
        ggml_fused_op(node, 1) == GGML_OP_ADD) { 
        ggml_cuda_mul_mat_bias(ctx, node);
    } else {
        // fused op not supported, decompose into the individual ops and compute
        for (int i = 0; i < ggml_fused_n_ops(node); i++) {
            ggml_tensor op = ggml_unfuse(node, i);
            ggml_cuda_compute_forward(ctx, &op);
        }
    }
    break;

@ggerganov (Member)

That sounds like a good approach. Do you think this same approach can be applied to the existing convolution operators (related to ggml-org/ggml#873)? I.e., im2col + mul_mat is the "unfused path" and each backend could optionally implement more efficient "fused" convolutions.
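
For reference, the unfused path already exists in ggml: ggml_conv_2d is built from im2col + mul_mat roughly as sketched below (simplified from memory, with the reshape/permute back to the output layout omitted, so treat the details as approximate):

// unroll the input so that the convolution becomes a single matrix multiplication;
// this intermediate tensor is what the unfused path has to allocate
struct ggml_tensor * cols = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1,
                                        /*is_2D =*/ true, GGML_TYPE_F16);
struct ggml_tensor * conv = ggml_mul_mat(ctx,
        ggml_reshape_2d(ctx, cols,   cols->ne[0],   cols->ne[1]*cols->ne[2]*cols->ne[3]),
        ggml_reshape_2d(ctx, kernel, kernel->ne[0]*kernel->ne[1]*kernel->ne[2], kernel->ne[3]));
// ... followed by a reshape/permute back to the convolution output shape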

@ggerganov closed this on Jun 29, 2024
@ggerganov reopened this on Jun 29, 2024
@JohannesGaessler (Collaborator, Author)

When implementing fused ops, keep in mind that if the I/O scales linearly with batch size, then you only really get a performance improvement for small batch sizes (where the kernels are I/O-bound rather than compute-bound), so that is where the focus should be. (At the same time, it is the kernels for large batch sizes that contribute the most to compilation time/binary size.)

For simple fused ops in the context of llama.cpp I think making them opt-in is fine. Long-term, one of my goals is to enable training via ggml though, so a user writing training code for a model architecture would need to know about those fused ops. The more complicated cases like batching multiple matrix multiplications will probably get relatively tricky (though with the stream-k MMQ implementation that has become much less important). Automatic optimization of the ggml graph is not mutually exclusive with opt-in fused operations though; it would definitely be possible to implement the fused ops now and only implement automatic optimization at a later point.

@slaren (Member) commented Jun 29, 2024

im2col + mul_mat is the "unfused path" and each backend could optionally implement more efficient "fused" convolutions

The approach would work well for operations such as mul_mat + activation + bias + scale, etc., where the additional operations can be run in place, since that doesn't require allocating new tensors. I don't think that this would work very well in the im2col case, because the unfused path requires more memory than the fused path: im2col requires allocating an intermediate tensor to store the result. To be able to do this efficiently, the selection between the fused and unfused path would need to be done before the graph allocation.
